From 87c29451cf200cd81c3cbc9022b98b20940ebe61 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 22 Aug 2024 16:07:41 +0200 Subject: [PATCH] Deployed 5844490 with MkDocs version: 1.5.3 --- faq/index.html | 9 +++++++-- index.html | 6 +++--- search/search_index.json | 2 +- sitemap.xml.gz | Bin 439 -> 439 bytes 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/faq/index.html b/faq/index.html index 8260cc29..22fd4fa0 100644 --- a/faq/index.html +++ b/faq/index.html @@ -807,8 +807,13 @@

Frequently asked questions

-

What units does Jaxley use? -How can I save and load cells and networks?

+ +

See also the discussion page and the issue +tracker on the Jaxley GitHub repository for +recent questions and problems.

diff --git a/index.html b/index.html index 363aebdc..80c6839a 100644 --- a/index.html +++ b/index.html @@ -941,19 +941,19 @@

Home

  • elegant mechanisms for parameter sharing
  • Getting started

    -

    Jaxley allows to simulate biophysical neuron models on CPU or GPU: +

    Jaxley allows to simulate biophysical neuron models on CPU, GPU, or TPU:

    import matplotlib.pyplot as plt
     from jax import config
     
     import jaxley as jx
     from jaxley.channels import HH
     
    -config.update("jax_platform_name", "cpu")  # or "gpu".
    +config.update("jax_platform_name", "cpu")  # Or "gpu" / "tpu".
     
     cell = jx.Cell()  # Define cell.
     cell.insert(HH())  # Insert channels.
     
    -current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, dt=0.025, t_max=10.0)
    +current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)
     cell.stimulate(current)  # Stimulate with step current.
     cell.record("v")  # Record voltage.
     
    diff --git a/search/search_index.json b/search/search_index.json
    index 14213d5b..5079074f 100644
    --- a/search/search_index.json
    +++ b/search/search_index.json
    @@ -1 +1 @@
    -{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"

    Jaxley is a differentiable simulator for biophysical neuron models in JAX. Its key features are:

    • automatic differentiation, allowing gradient-based optimization of thousands of parameters
    • support for CPU, GPU, or TPU without any changes to the code
    • jit-compilation, making it as fast as other packages while being fully written in python
    • backward-Euler solver for stable numerical solution of multicompartment neurons
    • elegant mechanisms for parameter sharing
    "},{"location":"#getting-started","title":"Getting started","text":"

    Jaxley allows to simulate biophysical neuron models on CPU or GPU:

    import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\")  # or \"gpu\".\n\ncell = jx.Cell()  # Define cell.\ncell.insert(HH())  # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, dt=0.025, t_max=10.0)\ncell.stimulate(current)  # Stimulate with step current.\ncell.record(\"v\")  # Record voltage.\n\nv = jx.integrate(cell)  # Run simulation.\nplt.plot(v.T)  # Plot voltage trace.\n

    If you want to learn more, we have tutorials on how to:

    • simulate morphologically detailed neurons
    • simulate networks of such neurons
    • set parameters of cells and networks
    • speed up simulations with GPUs and jit
    • define your own channels and synapses
    • define groups
    • read and handle SWC files
    • compute the gradient and train biophysical models
    "},{"location":"#installation","title":"Installation","text":"

    Jaxley is available on pypi:

    pip install jaxley\n
    This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
    pip install -U \"jax[cuda12]\"\n

    "},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"

    We welcome any feedback on how Jaxley is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.

    "},{"location":"#license","title":"License","text":"

    Apache License Version 2.0 (Apache-2.0)

    "},{"location":"#citation","title":"Citation","text":"

    If you use Jaxley, consider citing the corresponding paper:

    @article{deistler2024differentiable,\n  doi = {10.1101/2024.08.21.608979},\n  year = {2024},\n  publisher = {Cold Spring Harbor Laboratory},\n  author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n  title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n  journal = {bioRxiv}\n}\n
    "},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"

    We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.

    We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

    "},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"

    Examples of behavior that contributes to a positive environment for our community include:

    • Demonstrating empathy and kindness toward other people
    • Being respectful of differing opinions, viewpoints, and experiences
    • Giving and gracefully accepting constructive feedback
    • Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
    • Focusing on what is best not just for us as individuals, but for the overall community

    Examples of unacceptable behavior include:

    • The use of sexualized language or imagery, and sexual attention or advances of any kind
    • Trolling, insulting or derogatory comments, and personal or political attacks
    • Public or private harassment
    • Publishing others\u2019 private information, such as a physical or email address, without their explicit permission
    • Other conduct which could reasonably be considered inappropriate in a professional setting
    "},{"location":"code_of_conduct/#enforcement-responsibilities","title":"Enforcement Responsibilities","text":"

    Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

    Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

    "},{"location":"code_of_conduct/#scope","title":"Scope","text":"

    This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

    "},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"

    Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.

    All community leaders are obligated to respect the privacy and security of the reporter of any incident.

    "},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"

    Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

    "},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"

    Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

    Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

    "},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"

    Community Impact: A violation through a single incident or series of actions.

    Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

    "},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"

    Community Impact: A serious violation of community standards, including sustained inappropriate behavior.

    Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

    "},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"

    Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

    Consequence: A permanent ban from any sort of public interaction within the community.

    "},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"

    This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

    Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.

    For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

    "},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"

    To report bugs and suggest features (including better documentation), please head over to issues on GitHub.

    "},{"location":"contribute/#code-contributions","title":"Code contributions","text":"

    In general, we use pull requests to make changes to Jaxley. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley (details).

    "},{"location":"contribute/#development-environment","title":"Development environment","text":"

    Clone the repo and install via setup.py using pip install -e \".[dev]\" (the dev flag installs development and testing dependencies).

    "},{"location":"contribute/#style-conventions","title":"Style conventions","text":"

    For docstrings and comments, we use Google Style.

    Code needs to pass through the following tools, which are installed alongside Jaxley:

    black: Automatic code formatting for Python. You can run black manually from the console using black . in the top directory of the repository, which will format all files.

    isort: Used to consistently order imports. You can run isort manually from the console using isort in the top directory.

    black and isort are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.

    "},{"location":"contribute/#online-documentation","title":"Online documentation","text":"

    Most of the documentation is written in markdown (basic markdown guide).

    You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.

    "},{"location":"credits/","title":"Credits","text":"

    Jaxley is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).

    "},{"location":"credits/#license","title":"License","text":"

    Jaxley is licensed under the Apache License Version 2.0 (Apache-2.0) and

    Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.

    "},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"
    • We greatly benefited from previous toolboxes for simulating multicompartment neurons, in particular NEURON.
    "},{"location":"credits/#funding","title":"Funding","text":"

    This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).

    "},{"location":"faq/","title":"Frequently asked questions","text":"

    What units does Jaxley use? How can I save and load cells and networks?

    "},{"location":"install/","title":"Installation","text":""},{"location":"install/#install-the-most-recent-stable-version","title":"Install the most recent stable version","text":"

    Jaxley is available on PyPI:

    pip install jaxley\n
    This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
    pip install -U \"jax[cuda12]\"\n

    "},{"location":"install/#install-from-source","title":"Install from source","text":"

    You can also install Jaxley from source:

    git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n

    Note that pip>=21.3 is required to install the editable version with pyproject.toml see pip docs.

    "},{"location":"faq/question_01/","title":"What units does Jaxley use?","text":"

    Jaxley uses the same units as the NEURON simulator, which are listed here.

    "},{"location":"faq/question_02/","title":"How can I save and load cells and networks?","text":"

    All modules (i.e., compartments, branches, cells, and networks) in Jaxley can be saved and loaded with pickle:

    import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n    pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n    network = pickle.load(handle)\n

    "},{"location":"reference/connect/","title":"Connecting Cells","text":""},{"location":"reference/connect/#jaxley.connect.connect","title":"connect(pre, post, synapse_type)","text":"

    Connect two compartments with a chemical synapse.

    The pre- and postsynaptic compartments must be different compartments of the same network.

    Parameters:

    Name Type Description Default pre CompartmentView

    View of the presynaptic compartment.

    required post CompartmentView

    View of the postsynaptic compartment.

    required synapse_type Synapse

    The synapse to append

    required Source code in jaxley/connect.py
    def connect(\n    pre: \"CompartmentView\",\n    post: \"CompartmentView\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Connect two compartments with a chemical synapse.\n\n    The pre- and postsynaptic compartments must be different compartments of the\n    same network.\n\n    Args:\n        pre: View of the presynaptic compartment.\n        post: View of the postsynaptic compartment.\n        synapse_type: The synapse to append\n    \"\"\"\n    assert is_same_network(\n        pre, post\n    ), \"Pre and post compartments must be part of the same network.\"\n    assert np.all(\n        pre_comp_not_equal_post_comp(pre, post)\n    ), \"Pre and post compartments must be different.\"\n\n    pre._append_multiple_synapses(pre.view, post.view, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)","text":"

    Appends multiple connections which build a custom connected network.

    Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required connectivity_matrix ndarray[bool]

    A boolean matrix indicating the connections between cells.

    required Source code in jaxley/connect.py
    def connectivity_matrix_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n    connectivity_matrix: np.ndarray[bool],\n):\n    \"\"\"Appends multiple connections which build a custom connected network.\n\n    Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n    Entries > 0 in the matrix indicate a connection between the corresponding cells.\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        connectivity_matrix: A boolean matrix indicating the connections between cells.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n\n    assert connectivity_matrix.shape == (\n        pre_cell_view.shape[0],\n        post_cell_view.shape[0],\n    ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n    assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n    # get connection pairs from connectivity matrix\n    from_idx, to_idx = np.where(connectivity_matrix)\n    pre_cell_inds = pre_cell_inds[from_idx]\n    post_cell_inds = post_cell_inds[to_idx]\n\n    # Sample random postsynaptic compartments (global comp indices).\n    global_post_indices = [\n        sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_cell_inds\n    ]\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    idcs_to_zero = np.zeros_like(from_idx)\n    get_global_idx = post_cell_view.pointer._local_inds_to_global\n    global_pre_indices = get_global_idx(pre_cell_inds, idcs_to_zero, idcs_to_zero)\n    pre_rows = pre_cell_view.view.loc[global_pre_indices]\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)","text":"

    Appends multiple connections which build a fully connected layer.

    Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required Source code in jaxley/connect.py
    def fully_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Appends multiple connections which build a fully connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n    num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)\n\n    # Infer indices of (random) postsynaptic compartments.\n    global_post_indices = (\n        post_cell_view.view.groupby(\"cell_index\")\n        .sample(num_pre, replace=True)\n        .index.to_numpy()\n    )\n    global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    pre_rows = pre_cell_view[0, 0].view\n    # Repeat rows `num_post` times. See SO 50788508.\n    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.get_pre_post_inds","title":"get_pre_post_inds(pre_cell_view, post_cell_view)","text":"

    Get the unique cell indices of the pre- and postsynaptic cells.

    Source code in jaxley/connect.py
    def get_pre_post_inds(\n    pre_cell_view: \"CellView\", post_cell_view: \"CellView\"\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get the unique cell indices of the pre- and postsynaptic cells.\"\"\"\n    pre_cell_inds = np.unique(pre_cell_view.view[\"cell_index\"].to_numpy())\n    post_cell_inds = np.unique(post_cell_view.view[\"cell_index\"].to_numpy())\n    return pre_cell_inds, post_cell_inds\n
    "},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)","text":"

    Check if views are from the same network.

    Source code in jaxley/connect.py
    def is_same_network(pre: \"View\", post: \"View\") -> bool:\n    \"\"\"Check if views are from the same network.\"\"\"\n    is_in_net = \"network\" in pre.pointer.__class__.__name__.lower()\n    is_in_same_net = pre.pointer is post.pointer\n    return is_in_net and is_in_same_net\n
    "},{"location":"reference/connect/#jaxley.connect.pre_comp_not_equal_post_comp","title":"pre_comp_not_equal_post_comp(pre, post)","text":"

    Check if pre and post compartments are different.

    Source code in jaxley/connect.py
    def pre_comp_not_equal_post_comp(\n    pre: \"CompartmentView\", post: \"CompartmentView\"\n) -> np.ndarray[bool]:\n    \"\"\"Check if pre and post compartments are different.\"\"\"\n    cols = [\"cell_index\", \"branch_index\", \"comp_index\"]\n    return np.any(pre.view[cols].values != post.view[cols].values, axis=1)\n
    "},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, cell_idx, num=1, replace=True)","text":"

    Sample a compartment from a cell.

    Returns View with shape (num, num_cols).

    Source code in jaxley/connect.py
    def sample_comp(\n    cell_view: \"CellView\", cell_idx: int, num: int = 1, replace=True\n) -> \"CompartmentView\":\n    \"\"\"Sample a compartment from a cell.\n\n    Returns View with shape (num, num_cols).\"\"\"\n    cell_idx_view = lambda view, cell_idx: view[view[\"cell_index\"] == cell_idx]\n    return cell_idx_view(cell_view.view, cell_idx).sample(num, replace=replace)\n
    "},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)","text":"

    Appends multiple connections which build a sparse, randomly connected layer.

    Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required p float

    Probability of connection.

    required Source code in jaxley/connect.py
    def sparse_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n    p: float,\n):\n    \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        p: Probability of connection.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n    num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)\n\n    num_connections = np.random.binomial(num_pre * num_post, p)\n    pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n    post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n    # Sort the synapses only for convenience of inspecting `.edges`.\n    sorting = np.argsort(pre_syn_neurons)\n    pre_syn_neurons = pre_syn_neurons[sorting]\n    post_syn_neurons = post_syn_neurons[sorting]\n\n    # Post-synapse is a randomly chosen branch and compartment.\n    global_post_indices = [\n        sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_syn_neurons\n    ]\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    idcs_to_zero = np.zeros_like(num_pre)\n    get_global_idx = pre_cell_view.pointer._local_inds_to_global\n    global_pre_indices = get_global_idx(pre_syn_neurons, idcs_to_zero, idcs_to_zero)\n    pre_rows = pre_cell_view.view.loc[global_pre_indices]\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)","text":"

    Solves ODE and simulates neuron model.

    Parameters:

    Name Type Description Default params List[Dict[str, ndarray]]

    Trainable parameters returned by get_parameters().

    [] param_state Optional[List[Dict]]

    Parameters returned by data_set.

    None data_stimuli Optional[Tuple[ndarray, DataFrame]]

    Outputs of .data_stimulate(), only needed if stimuli change across function calls.

    None t_max Optional[float]

    Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, then the stimulus with be truncated.

    None delta_t float

    Time step of the solver in milliseconds.

    0.025 solver str

    Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccranck\u201d].

    'bwd_euler' tridiag_solver

    Algorithm to solve tridiagonal systems. The different options only affect bwd_euler and cranck solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON).

    required checkpoint_lengths Optional[List[int]]

    Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths) must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation time. If None, no checkpointing is applied.

    None all_states Optional[Dict]

    An optional initial state that was returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states.

    None return_states bool

    If True, it returns all states such that the current state of the Module can be set with set_states.

    False Source code in jaxley/integrate.py
    def integrate(\n    module: Module,\n    params: List[Dict[str, jnp.ndarray]] = [],\n    *,\n    param_state: Optional[List[Dict]] = None,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    t_max: Optional[float] = None,\n    delta_t: float = 0.025,\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n    checkpoint_lengths: Optional[List[int]] = None,\n    all_states: Optional[Dict] = None,\n    return_states: bool = False,\n) -> jnp.ndarray:\n    \"\"\"\n    Solves ODE and simulates neuron model.\n\n    Args:\n        params: Trainable parameters returned by `get_parameters()`.\n        param_state: Parameters returned by `data_set`.\n        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n            across function calls.\n        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n            the length of the stimulus input, the stimulus will be padded at the end\n            with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n        delta_t: Time step of the solver in milliseconds.\n        solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\", \"cranck\"].\n        tridiag_solver: Algorithm to solve tridiagonal systems. The  different options\n            only affect `bwd_euler` and `cranck` solvers. Either of [\"stone\",\n            \"thomas\"], where `stone` is much faster on GPU for long branches\n            with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n            used in NEURON).\n        checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n            `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n            simulated timesteps. Warning: the simulation is run for\n            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n            to the desired simulation length. Therefore, a poor choice of\n            `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n            checkpointing is applied.\n        all_states: An optional initial state that was returned by a previous\n            `jx.integrate(..., return_states=True)` run. Overrides potentially\n            trainable initial states.\n        return_states: If True, it returns all states such that the current state of\n            the `Module` can be set with `set_states`.\n    \"\"\"\n\n    assert module.initialized, \"Module is not initialized, run `.initialize()`.\"\n    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n    # Initialize the external inputs and their indices.\n    externals = module.externals.copy()\n    external_inds = module.external_inds.copy()\n\n    # If stimulus is inserted, add it to the external inputs.\n    if \"i\" in module.externals.keys() or data_stimuli is not None:\n        if \"i\" in module.externals.keys():\n            if data_stimuli is not None:\n                externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[0]])\n                external_inds[\"i\"] = jnp.concatenate(\n                    [external_inds[\"i\"], data_stimuli[1].comp_index.to_numpy()]\n                )\n        else:\n            externals[\"i\"] = data_stimuli[0]\n            external_inds[\"i\"] = data_stimuli[1].comp_index.to_numpy()\n    else:\n        externals[\"i\"] = jnp.asarray([[]]).astype(\"float\")\n        external_inds[\"i\"] = jnp.asarray([]).astype(\"int32\")\n\n    if not externals.keys():\n        # No stimulus was inserted and no clamp was set.\n        assert (\n            t_max is not None\n        ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n    for key in externals.keys():\n        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.\n\n    rec_inds = module.recordings.rec_index.to_numpy()\n    rec_states = module.recordings.state.to_numpy()\n\n    # Shorten or pad stimulus depending on `t_max`.\n    if t_max is not None:\n        t_max_steps = int(t_max // delta_t + 1)\n\n        # Pad or truncate the stimulus.\n        if \"i\" in externals.keys() and t_max_steps > externals[\"i\"].shape[0]:\n            pad = jnp.zeros(\n                (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n            )\n            externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n\n        for key in externals.keys():\n            if t_max_steps > externals[key].shape[0]:\n                raise NotImplementedError(\n                    \"clamp must be at least as long as simulation.\"\n                )\n            else:\n                externals[key] = externals[key][:t_max_steps, :]\n\n    # Make the `trainable_params` of the same shape as the `param_state`, such that they\n    # can be processed together by `get_all_parameters`.\n    pstate = params_to_pstate(params, module.indices_set_by_trainables)\n\n    # Gather parameters from `make_trainable` and `data_set` into a single list.\n    if param_state is not None:\n        pstate += param_state\n\n    all_params = module.get_all_parameters(pstate)\n    all_states = (\n        module.get_all_states(pstate, all_params, delta_t)\n        if all_states is None\n        else all_states\n    )\n\n    def _body_fun(state, externals):\n        state = module.step(\n            state,\n            delta_t,\n            external_inds,\n            externals,\n            params=all_params,\n            solver=solver,\n            voltage_solver=voltage_solver,\n        )\n        recs = jnp.asarray(\n            [\n                state[rec_state][rec_ind]\n                for rec_state, rec_ind in zip(rec_states, rec_inds)\n            ]\n        )\n        return state, recs\n\n    # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n    # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n    # return only the first `nsteps_to_return` elements (plus the initial state).\n    example_key = list(externals.keys())[0]\n    nsteps_to_return = len(externals[example_key])\n    if checkpoint_lengths is None:\n        checkpoint_lengths = [len(externals[example_key])]\n        length = len(externals[example_key])\n    else:\n        length = prod(checkpoint_lengths)\n        size_difference = length - len(externals[example_key])\n        dummy_external = jnp.zeros((size_difference, externals[example_key].shape[1]))\n        assert (\n            len(externals[example_key]) <= length\n        ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n        for key in externals.keys():\n            externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n    # Record the initial state.\n    init_recs = jnp.asarray(\n        [\n            all_states[rec_state][rec_ind]\n            for rec_state, rec_ind in zip(rec_states, rec_inds)\n        ]\n    )\n    init_recording = jnp.expand_dims(init_recs, axis=0)\n\n    # Run simulation.\n    all_states, recordings = nested_checkpoint_scan(\n        _body_fun,\n        all_states,\n        externals,\n        length=length,\n        nested_lengths=checkpoint_lengths,\n    )\n    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n    return (recs, all_states) if return_states else recs\n
    "},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)","text":"

    An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau.

    Source code in jaxley/solver_gate.py
    def exponential_euler(\n    x: jnp.ndarray,\n    dt: float,\n    x_inf: jnp.ndarray,\n    x_tau: jnp.ndarray,\n):\n    \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n    exp_term = save_exp(-dt / x_tau)\n    return x * exp_term + x_inf * (1.0 - exp_term)\n
    "},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)","text":"

    Clip the input to a maximum value and return its exponential.

    Source code in jaxley/solver_gate.py
    def save_exp(x, max_value: float = 20.0):\n    \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n    x = jnp.clip(x, a_max=max_value)\n    return jnp.exp(x)\n
    "},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)","text":"

    solves dx/dt = (s_inf - x) / tau_s via exponential Euler

    Parameters:

    Name Type Description Default x ndarray

    gate variable

    required dt float

    time_delta

    required s_inf ndarray

    description

    required tau_s ndarray

    description

    required

    Returns:

    Name Type Description _type_

    updated gate

    Source code in jaxley/solver_gate.py
    def solve_inf_gate_exponential(\n    x: jnp.ndarray,\n    dt: float,\n    s_inf: jnp.ndarray,\n    tau_s: jnp.ndarray,\n):\n    \"\"\"solves dx/dt = (s_inf - x) / tau_s\n    via exponential Euler\n\n    Args:\n        x (jnp.ndarray): gate variable\n        dt (float): time_delta\n        s_inf (jnp.ndarray): _description_\n        tau_s (jnp.ndarray): _description_\n\n    Returns:\n        _type_: updated gate\n    \"\"\"\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * dt)\n    return x * exp_term + s_inf * (1.0 - exp_term)\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd, nbranches, parents, delta_t)","text":"

    Solve one timestep of branched nerve equations with explicit (forward) Euler.

    Source code in jaxley/solver_voltage.py
    def step_voltage_explicit(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    coupling_conds_bwd: jnp.ndarray,\n    coupling_conds_fwd: jnp.ndarray,\n    branch_cond_fwd: jnp.ndarray,\n    branch_cond_bwd: jnp.ndarray,\n    nbranches: int,\n    parents: jnp.ndarray,\n    delta_t: float,\n) -> jnp.ndarray:\n    \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n    update = voltage_vectorfield(\n        parents,\n        voltages,\n        voltage_terms,\n        constant_terms,\n        coupling_conds_bwd,\n        coupling_conds_fwd,\n        branch_cond_fwd,\n        branch_cond_bwd,\n    )\n    new_voltates = voltages + delta_t * update\n    return new_voltates\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit","title":"step_voltage_implicit(voltages, voltage_terms, constant_terms, coupling_conds_upper, coupling_conds_lower, summed_coupling_conds, branchpoint_conds_children, branchpoint_conds_parents, branchpoint_weights_children, branchpoint_weights_parents, par_inds, child_inds, nbranches, solver, delta_t, children_in_level, parents_in_level, root_inds, branchpoint_group_inds, debug_states)","text":"

    Solve one timestep of branched nerve equations with implicit (backward) Euler.

    Source code in jaxley/solver_voltage.py
    def step_voltage_implicit(\n    voltages,\n    voltage_terms,\n    constant_terms,\n    coupling_conds_upper,\n    coupling_conds_lower,\n    summed_coupling_conds,\n    branchpoint_conds_children,\n    branchpoint_conds_parents,\n    branchpoint_weights_children,\n    branchpoint_weights_parents,\n    par_inds,\n    child_inds,\n    nbranches,\n    solver: str,\n    delta_t,\n    children_in_level,\n    parents_in_level,\n    root_inds,\n    branchpoint_group_inds,\n    debug_states,\n):\n    \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n    coupling_conds_upper = jnp.reshape(coupling_conds_upper, (nbranches, -1))\n    coupling_conds_lower = jnp.reshape(coupling_conds_lower, (nbranches, -1))\n    summed_coupling_conds = jnp.reshape(summed_coupling_conds, (nbranches, -1))\n\n    # Define quasi-tridiagonal system.\n    lowers, diags, uppers, solves = define_all_tridiags(\n        voltages,\n        voltage_terms,\n        constant_terms,\n        nbranches,\n        coupling_conds_upper,\n        coupling_conds_lower,\n        summed_coupling_conds,\n        delta_t,\n    )\n    all_branchpoint_vals = jnp.concatenate(\n        [branchpoint_weights_parents, branchpoint_weights_children]\n    )\n    # Find unique group identifiers\n    num_branchpoints = len(branchpoint_conds_parents)\n    branchpoint_diags = -group_and_sum(\n        all_branchpoint_vals, branchpoint_group_inds, num_branchpoints\n    )\n    branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n    branchpoint_conds_children = -delta_t * branchpoint_conds_children\n    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n    # Here, I move all child and parent indices towards a branchpoint into a larger\n    # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n    # makes the speed difference negligible.\n    # Children.\n    bp_conds_children = jnp.zeros(nbranches)\n    bp_weights_children = jnp.zeros(nbranches)\n    # Parents.\n    bp_conds_parents = jnp.zeros(nbranches)\n    bp_weights_parents = jnp.zeros(nbranches)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        bp_conds_children = bp_conds_children.at[child_inds].set(\n            branchpoint_conds_children\n        )\n        bp_weights_children = bp_weights_children.at[child_inds].set(\n            branchpoint_weights_children\n        )\n        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n        bp_weights_parents = bp_weights_parents.at[par_inds].set(\n            branchpoint_weights_parents\n        )\n\n    # Triangulate the linear system of equations.\n    (\n        diags,\n        lowers,\n        solves,\n        uppers,\n        branchpoint_diags,\n        branchpoint_solves,\n        bp_weights_children,\n        bp_conds_parents,\n    ) = _triang_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        children_in_level,\n        parents_in_level,\n        root_inds,\n        debug_states,\n    )\n\n    # Backsubstitute the linear system of equations.\n    (\n        solves,\n        lowers,\n        diags,\n        bp_weights_parents,\n        branchpoint_solves,\n        bp_conds_children,\n    ) = _backsub_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        children_in_level,\n        parents_in_level,\n        root_inds,\n        debug_states,\n    )\n\n    return solves\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.voltage_vectorfield","title":"voltage_vectorfield(parents, voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd)","text":"

    Evaluate the vectorfield of the nerve equation.

    Source code in jaxley/solver_voltage.py
    def voltage_vectorfield(\n    parents: jnp.ndarray,\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    coupling_conds_bwd: jnp.ndarray,\n    coupling_conds_fwd: jnp.ndarray,\n    branch_cond_fwd: jnp.ndarray,\n    branch_cond_bwd: jnp.ndarray,\n) -> jnp.ndarray:\n    \"\"\"Evaluate the vectorfield of the nerve equation.\"\"\"\n    # Membrane current update.\n    vecfield = -voltage_terms * voltages + constant_terms\n\n    # Current through segments within the same branch.\n    vecfield = vecfield.at[:, :-1].add(\n        (voltages[:, 1:] - voltages[:, :-1]) * coupling_conds_bwd\n    )\n    vecfield = vecfield.at[:, 1:].add(\n        (voltages[:, :-1] - voltages[:, 1:]) * coupling_conds_fwd\n    )\n\n    # Current through branch points.\n    if len(branch_cond_bwd) > 0:\n        vecfield = vecfield.at[:, -1].add(\n            (voltages[parents, 0] - voltages[:, -1]) * branch_cond_bwd\n        )\n\n        # Several branches might have the same parent, so we have to either update these\n        # entries sequentially or we have to build a matrix with width being the maximum\n        # number of children and then sum.\n        term_to_add = (voltages[:, -1] - voltages[parents, 0]) * branch_cond_fwd\n        inds = jnp.stack([parents, jnp.zeros_like(parents)]).T\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0, 1),\n            scatter_dims_to_operand_dims=(0, 1),\n        )\n        vecfield = scatter_add(vecfield, inds, term_to_add, dnums)\n\n    return vecfield\n
    "},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"

    Channel base class. All channels inherit from this class.

    As in NEURON, a Channel is considered a distributed process, which means that its conductances are to be specified in S/cm2 and its currents are to be specified in uA/cm2.

    Source code in jaxley/channels/channel.py
    class Channel:\n    \"\"\"Channel base class. All channels inherit from this class.\n\n    As in NEURON, a `Channel` is considered a distributed process, which means that its\n    conductances are to be specified in `S/cm2` and its currents are to be specified in\n    `uA/cm2`.\"\"\"\n\n    _name = None\n    channel_params = None\n    channel_states = None\n    current_name = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the channel name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.channel_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_params.items()\n        }\n\n        self.channel_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_states.items()\n        }\n        return self\n\n    def update_states(\n        self, states, dt, v, params\n    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Given channel states and voltage, return the current through the channel.\n\n        Args:\n            states: All states of the compartment.\n            v: Voltage of the compartment in mV.\n            params: Parameters of the channel (conductances in `S/cm2`).\n\n        Returns:\n            Current in `uA/cm2`.\n        \"\"\"\n        raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str] property","text":"

    The name of the channel (by default, this is the class name).

    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)","text":"

    Change the channel name.

    Parameters:

    Name Type Description Default new_name str

    The new name of the channel.

    required

    Returns:

    Type Description

    Renamed channel, such that this function is chainable.

    Source code in jaxley/channels/channel.py
    def change_name(self, new_name: str):\n    \"\"\"Change the channel name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.channel_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_params.items()\n    }\n\n    self.channel_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_states.items()\n    }\n    return self\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)","text":"

    Given channel states and voltage, return the current through the channel.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    All states of the compartment.

    required v

    Voltage of the compartment in mV.

    required params Dict[str, ndarray]

    Parameters of the channel (conductances in S/cm2).

    required

    Returns:

    Type Description

    Current in uA/cm2.

    Source code in jaxley/channels/channel.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Given channel states and voltage, return the current through the channel.\n\n    Args:\n        states: All states of the compartment.\n        v: Voltage of the compartment in mV.\n        params: Parameters of the channel (conductances in `S/cm2`).\n\n    Returns:\n        Current in `uA/cm2`.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)","text":"

    Return the updated states.

    Source code in jaxley/channels/channel.py
    def update_states(\n    self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n    \"\"\"Return the updated states.\"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#hh","title":"HH","text":"

    Bases: Channel

    Hodgkin-Huxley channel.

    Source code in jaxley/channels/hh.py
    class HH(Channel):\n    \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 0.12,\n            f\"{prefix}_gK\": 0.036,\n            f\"{prefix}_gLeak\": 0.0003,\n            f\"{prefix}_eNa\": 50.0,\n            f\"{prefix}_eK\": -77.0,\n            f\"{prefix}_eLeak\": -54.3,\n        }\n        self.channel_states = {\n            f\"{prefix}_m\": 0.2,\n            f\"{prefix}_h\": 0.2,\n            f\"{prefix}_n\": 0.2,\n        }\n        self.current_name = f\"i_HH\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Return updated HH channel state.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current through HH channels.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n        gK = params[f\"{prefix}_gK\"] * n**4 * 1000  # mS/cm^2\n        gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n\n        return (\n            gNa * (v - params[f\"{prefix}_eNa\"])\n            + gK * (v - params[f\"{prefix}_eK\"])\n            + gLeak * (v - params[f\"{prefix}_eLeak\"])\n        )\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v)\n        alpha_h, beta_h = self.h_gate(v)\n        alpha_n, beta_n = self.n_gate(v)\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n            f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n        }\n\n    @staticmethod\n    def m_gate(v):\n        alpha = 0.1 * _vtrap(-(v + 40), 10)\n        beta = 4.0 * save_exp(-(v + 65) / 18)\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v):\n        alpha = 0.07 * save_exp(-(v + 65) / 20)\n        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n        return alpha, beta\n\n    @staticmethod\n    def n_gate(v):\n        alpha = 0.01 * _vtrap(-(v + 55), 10)\n        beta = 0.125 * save_exp(-(v + 65) / 80)\n        return alpha, beta\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)","text":"

    Return current through HH channels.

    Source code in jaxley/channels/hh.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current through HH channels.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n    gK = params[f\"{prefix}_gK\"] * n**4 * 1000  # mS/cm^2\n    gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n\n    return (\n        gNa * (v - params[f\"{prefix}_eNa\"])\n        + gK * (v - params[f\"{prefix}_eK\"])\n        + gLeak * (v - params[f\"{prefix}_eLeak\"])\n    )\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/hh.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v)\n    alpha_h, beta_h = self.h_gate(v)\n    alpha_n, beta_n = self.n_gate(v)\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)","text":"

    Return updated HH channel state.

    Source code in jaxley/channels/hh.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Return updated HH channel state.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
    "},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":"

    Bases: Channel

    Leak current

    Source code in jaxley/channels/pospischil.py
    class Leak(Channel):\n    \"\"\"Leak current\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gLeak\": 1e-4,\n            f\"{prefix}_eLeak\": -70.0,\n        }\n        self.channel_states = {}\n        self.current_name = f\"i_{prefix}\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"No state to update.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n        return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n    def init_state(self, v, params):\n        return {}\n

    Bases: Channel

    Sodium channel

    Source code in jaxley/channels/pospischil.py
    class Na(Channel):\n    \"\"\"Sodium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 50e-3,\n            \"eNa\": 50.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n        self.current_name = f\"i_Na\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n\n        current = gNa * (v - params[\"eNa\"])\n        return current\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n        alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        }\n\n    @staticmethod\n    def m_gate(v, vt):\n        v_alpha = v - vt - 13.0\n        alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n        v_beta = v - vt - 40.0\n        beta = 0.28 * efun(0.2 * v_beta) / 0.2\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v, vt):\n        v_alpha = v - vt - 17.0\n        alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n        v_beta = v - vt - 40.0\n        beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n        return alpha, beta\n

    Bases: Channel

    Potassium channel

    Source code in jaxley/channels/pospischil.py
    class K(Channel):\n    \"\"\"Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gK\": 5e-3,\n            \"eK\": -90.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_n\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gK = params[f\"{prefix}_gK\"] * (n**4) * 1000  # mS/cm^2\n\n        return gK * (v - params[\"eK\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n        return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n    @staticmethod\n    def n_gate(v, vt):\n        v_alpha = v - vt - 15.0\n        alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n        v_beta = v - vt - 10.0\n        beta = 0.5 * save_exp(-v_beta / 40.0)\n        return alpha, beta\n

    Bases: Channel

    Slow M Potassium channel

    Source code in jaxley/channels/pospischil.py
    class Km(Channel):\n    \"\"\"Slow M Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gKm\": 0.004e-3,\n            f\"{prefix}_taumax\": 4000.0,\n            f\"eK\": -90.0,\n        }\n        self.channel_states = {f\"{prefix}_p\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n        new_p = solve_inf_gate_exponential(\n            p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n        )\n        return {f\"{prefix}_p\": new_p}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gKm = params[f\"{prefix}_gKm\"] * p * 1000  # mS/cm^2\n        return gKm * (v - params[\"eK\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n        return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n    @staticmethod\n    def p_gate(v, taumax):\n        v_p = v + 35.0\n        p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n        tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n        return p_inf, tau_p\n

    Bases: Channel

    L-type Calcium channel

    Source code in jaxley/channels/pospischil.py
    class CaL(Channel):\n    \"\"\"L-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaL\": 0.1e-3,\n            \"eCa\": 120.0,\n        }\n        self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n        new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n        return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r * 1000  # mS/cm^2\n\n        return gCaL * (v - params[\"eCa\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_q, beta_q = self.q_gate(v)\n        alpha_r, beta_r = self.r_gate(v)\n        return {\n            f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n            f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n        }\n\n    @staticmethod\n    def q_gate(v):\n        v_alpha = -v - 27.0\n        alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n        v_beta = -v - 75.0\n        beta = 0.94 * save_exp(v_beta / 17.0)\n        return alpha, beta\n\n    @staticmethod\n    def r_gate(v):\n        v_alpha = -v - 13.0\n        alpha = 0.000457 * save_exp(v_alpha / 50)\n\n        v_beta = -v - 15.0\n        beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n        return alpha, beta\n

    Bases: Channel

    T-type Calcium channel

    Source code in jaxley/channels/pospischil.py
    class CaT(Channel):\n    \"\"\"T-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaT\": 0.4e-4,\n            f\"{prefix}_vx\": 2.0,\n            \"eCa\": 120.0,  # Global parameter, not prefixed with `CaT`.\n        }\n        self.channel_states = {f\"{prefix}_u\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        new_u = solve_inf_gate_exponential(\n            u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n        )\n        return {f\"{prefix}_u\": new_u}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u * 1000  # mS/cm^2\n\n        return gCaT * (v - params[\"eCa\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n        return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n    @staticmethod\n    def u_gate(v, vx):\n        v_u1 = v + vx + 81.0\n        u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n        tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n            3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n        )\n\n        return u_inf, tau_u\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n    return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)","text":"

    No state to update.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"No state to update.\"\"\"\n    return {}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n\n    current = gNa * (v - params[\"eNa\"])\n    return current\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n    alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gK = params[f\"{prefix}_gK\"] * (n**4) * 1000  # mS/cm^2\n\n    return gK * (v - params[\"eK\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n    return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_n\": new_n}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gKm = params[f\"{prefix}_gKm\"] * p * 1000  # mS/cm^2\n    return gKm * (v - params[\"eK\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n    return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n    new_p = solve_inf_gate_exponential(\n        p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n    )\n    return {f\"{prefix}_p\": new_p}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r * 1000  # mS/cm^2\n\n    return gCaL * (v - params[\"eCa\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_q, beta_q = self.q_gate(v)\n    alpha_r, beta_r = self.r_gate(v)\n    return {\n        f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n        f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n    new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n    return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u * 1000  # mS/cm^2\n\n    return gCaT * (v - params[\"eCa\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n    return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    new_u = solve_inf_gate_exponential(\n        u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n    )\n    return {f\"{prefix}_u\": new_u}\n
    "},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"

    Base class for a synapse.

    As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

    Source code in jaxley/synapses/synapse.py
    class Synapse:\n    \"\"\"Base class for a synapse.\n\n    As in NEURON, a `Synapse` is considered a point process, which means that its\n    conductances are to be specified in `uS` and its currents are to be specified in\n    `nA`.\n    \"\"\"\n\n    _name = None\n    synapse_params = None\n    synapse_states = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the synapse name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.synapse_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_params.items()\n        }\n\n        self.synapse_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_states.items()\n        }\n        return self\n\n    def update_states(\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"ODE update step.\n\n        Args:\n            states: States of the synapse.\n            delta_t: Time step in `ms`.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        states: Dict[str, jnp.ndarray],\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> jnp.ndarray:\n        \"\"\"Return current through one synapse in `nA`.\n\n        Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n        Args:\n            states: States of the synapse.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Current through the synapse in `nA`, shape `()`.\n        \"\"\"\n        raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)","text":"

    Change the synapse name.

    Parameters:

    Name Type Description Default new_name str

    The new name of the channel.

    required

    Returns:

    Type Description

    Renamed channel, such that this function is chainable.

    Source code in jaxley/synapses/synapse.py
    def change_name(self, new_name: str):\n    \"\"\"Change the synapse name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.synapse_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_params.items()\n    }\n\n    self.synapse_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_states.items()\n    }\n    return self\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

    Return current through one synapse in nA.

    Internally, we use jax.vmap to vectorize this function across many synapses.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    States of the synapse.

    required pre_voltage ndarray

    Voltage of the presynaptic compartment, shape ().

    required post_voltage ndarray

    Voltage of the postsynaptic compartment, shape ().

    required params Dict[str, ndarray]

    Parameters of the synapse. Conductances in uS.

    required

    Returns:

    Type Description ndarray

    Current through the synapse in nA, shape ().

    Source code in jaxley/synapses/synapse.py
    def compute_current(\n    states: Dict[str, jnp.ndarray],\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n    \"\"\"Return current through one synapse in `nA`.\n\n    Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n    Args:\n        states: States of the synapse.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Current through the synapse in `nA`, shape `()`.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    ODE update step.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    States of the synapse.

    required delta_t float

    Time step in ms.

    required pre_voltage ndarray

    Voltage of the presynaptic compartment, shape ().

    required post_voltage ndarray

    Voltage of the postsynaptic compartment, shape ().

    required params Dict[str, ndarray]

    Parameters of the synapse. Conductances in uS.

    required

    Returns:

    Type Description Dict[str, ndarray]

    Updated states.

    Source code in jaxley/synapses/synapse.py
    def update_states(\n    states: Dict[str, jnp.ndarray],\n    delta_t: float,\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"ODE update step.\n\n    Args:\n        states: States of the synapse.\n        delta_t: Time step in `ms`.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Updated states.\"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":"

    Bases: Synapse

    Compute synaptic current and update synapse state for a generic ionotropic synapse.

    The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.

    The synaptic parameters are
    • gS: the maximal conductance across the postsynaptic membrane (uS)
    • e_syn: the reversal potential across the postsynaptic membrane (mV)
    • k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic receptor (s^-1)
    Details of this implementation can be found in the following book chapter

    L. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.

    Source code in jaxley/synapses/ionotropic.py
    class IonotropicSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n    The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n    open, and this depends on the amount of neurotransmitter released, which is in turn\n    dependent on the presynaptic voltage.\n\n    The synaptic parameters are:\n        - gS: the maximal conductance across the postsynaptic membrane (uS)\n        - e_syn: the reversal potential across the postsynaptic membrane (mV)\n        - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n            receptor (s^-1)\n\n    Details of this implementation can be found in the following book chapter:\n        L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n        Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_e_syn\": 0.0,\n            f\"{prefix}_k_minus\": 0.025,\n        }\n        self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        v_th = -35.0  # mV\n        delta = 10.0  # mV\n\n        s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n        tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n        slope = -1.0 / tau_s\n        exp_term = save_exp(slope * delta_t)\n        new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {f\"{prefix}_s\": new_s}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        prefix = self._name\n        g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n        return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/ionotropic.py
    def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    v_th = -35.0  # mV\n    delta = 10.0  # mV\n\n    s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n    tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * delta_t)\n    new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n    return {f\"{prefix}_s\": new_s}\n
    "},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":"

    Bases: Synapse

    Compute synaptic current for tanh synapse (no state).

    Source code in jaxley/synapses/tanh_rate.py
    class TanhRateSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current for tanh synapse (no state).\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_x_offset\": -70.0,\n            f\"{prefix}_slope\": 1.0,\n        }\n        self.synapse_states = {}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        current = (\n            -1\n            * params[f\"{prefix}_gS\"]\n            * jnp.tanh(\n                (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n            )\n        )\n        return current\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/tanh_rate.py
    def compute_current(\n    self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    current = (\n        -1\n        * params[f\"{prefix}_gS\"]\n        * jnp.tanh(\n            (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n        )\n    )\n    return current\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/tanh_rate.py
    def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    return {}\n
    "},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":"

    Bases: ABC

    Module base class.

    Modules are everything that can be passed to jx.integrate, i.e. compartments, branches, cells, and networks.

    This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).

    Source code in jaxley/modules/base.py
    class Module(ABC):\n    \"\"\"Module base class.\n\n    Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n    branches, cells, and networks.\n\n    This base class defines the scaffold for all jaxley modules (compartments,\n    branches, cells, networks).\n    \"\"\"\n\n    def __init__(self):\n        self.nseg: int = None\n        self.total_nbranches: int = 0\n        self.nbranches_per_cell: List[int] = None\n\n        self.group_nodes = {}\n\n        self.nodes: Optional[pd.DataFrame] = None\n\n        self.edges = pd.DataFrame(\n            columns=[\n                \"pre_locs\",\n                \"pre_branch_index\",\n                \"pre_cell_index\",\n                \"post_locs\",\n                \"post_branch_index\",\n                \"post_cell_index\",\n                \"type\",\n                \"type_ind\",\n                \"global_pre_comp_index\",\n                \"global_post_comp_index\",\n                \"global_pre_branch_index\",\n                \"global_post_branch_index\",\n            ]\n        )\n\n        self.cumsum_nbranches: Optional[jnp.ndarray] = None\n\n        self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n        self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])]\n\n        self.initialized_morph: bool = False\n        self.initialized_syns: bool = False\n\n        # List of all types of `jx.Synapse`s.\n        self.synapses: List = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n        self.synapse_names = []\n\n        # List of types of all `jx.Channel`s.\n        self.channels: List[Channel] = []\n        self.membrane_current_names: List[str] = []\n\n        # For trainable parameters.\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.allow_make_trainable: bool = True\n        self.num_trainable_params: int = 0\n\n        # For recordings.\n        self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n        # For stimuli or clamps.\n        # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n        # for 1000 timesteps and two compartments.\n        self.externals: Dict[str, jnp.ndarray] = {}\n        # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n        self.external_inds: Dict[str, jnp.ndarray] = {}\n\n        # x, y, z coordinates and radius.\n        self.xyzr: List[np.ndarray] = []\n\n        # For debugging the solver. Will be empty by default and only filled if\n        # `self._init_morph_for_debugging` is run.\n        self.debug_states = {}\n\n    def _update_nodes_with_xyz(self):\n        \"\"\"Add xyz coordinates to nodes.\"\"\"\n        num_branches = len(self.xyzr)\n        x = np.linspace(\n            0.5 / self.nseg,\n            (num_branches * 1 - 0.5 / self.nseg),\n            num_branches * self.nseg,\n        )\n        x += np.arange(num_branches).repeat(\n            self.nseg\n        )  # add offset to prevent branch loc overlap\n        xp = np.hstack(\n            [np.linspace(0, 1, x.shape[0]) + 2 * i for i, x in enumerate(self.xyzr)]\n        )\n        xyz = v_interp(x, xp, np.vstack(self.xyzr)[:, :3])\n        idcs = self.nodes[\"comp_index\"]\n        self.nodes.loc[idcs, [\"x\", \"y\", \"z\"]] = xyz.T\n        return xyz.T\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.show()` for details.\"\n\n    def __str__(self):\n        return f\"jx.{type(self).__name__}\"\n\n    def __dir__(self):\n        base_dir = object.__dir__(self)\n        return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n    def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n        \"\"\"Insert the default params of the module (e.g. radius, length).\n\n        This is run at `__init__()`. It does not deal with channels.\n        \"\"\"\n        for param_name, param_value in param_dict.items():\n            self.nodes[param_name] = param_value\n        for state_name, state_value in state_dict.items():\n            self.nodes[state_name] = state_value\n\n    def _gather_channels_from_constituents(self, constituents: List):\n        \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n        This is run at `__init__()`. It takes all branches of constituents (e.g.\n        of all branches when the are assembled into a cell) and adds columns to\n        `.nodes` for the relevant channels.\n        \"\"\"\n        for module in constituents:\n            for channel in module.channels:\n                if channel._name not in [c._name for c in self.channels]:\n                    self.channels.append(channel)\n                if channel.current_name not in self.membrane_current_names:\n                    self.membrane_current_names.append(channel.current_name)\n        # Setting columns of channel names to `False` instead of `NaN`.\n        for channel in self.channels:\n            name = channel._name\n            self.nodes.loc[self.nodes[name].isna(), name] = False\n\n    def to_jax(self):\n        \"\"\"Move `.nodes` to `.jaxnodes`.\n\n        Before the actual simulation is run (via `jx.integrate`), all parameters of\n        the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n        simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n        they can be processed on GPU/TPU and such that the simulation can be\n        differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n        \"\"\"\n        self.jaxnodes = {}\n        for key, value in self.nodes.to_dict(orient=\"list\").items():\n            inds = jnp.arange(len(value))\n            self.jaxnodes[key] = jnp.asarray(value)[inds]\n\n        # `jaxedges` contains only parameters (no indices).\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        self.jaxedges = {}\n        edges = self.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.synapses):\n            for key in synapse.synapse_params:\n                condition = np.asarray(edges[\"type_ind\"]) == i\n                self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n            for key in synapse.synapse_states:\n                self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n    def show(\n        self,\n        param_names: Optional[Union[str, List[str]]] = None,  # TODO.\n        *,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ) -> pd.DataFrame:\n        \"\"\"Print detailed information about the Module or a view of it.\n\n        Args:\n            param_names: The names of the parameters to show. If `None`, all parameters\n                are shown. NOT YET IMPLEMENTED.\n            indices: Whether to show the indices of the compartments.\n            params: Whether to show the parameters of the compartments.\n            states: Whether to show the states of the compartments.\n            channel_names: The names of the channels to show. If `None`, all channels are\n                shown.\n\n        Returns:\n            A `pd.DataFrame` with the requested information.\n        \"\"\"\n        return self._show(\n            self.nodes, param_names, indices, params, states, channel_names\n        )\n\n    def _show(\n        self,\n        view: pd.DataFrame,\n        param_names: Optional[Union[str, List[str]]] = None,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ):\n        \"\"\"Print detailed information about the entire Module.\"\"\"\n        printable_nodes = deepcopy(view)\n\n        for channel in self.channels:\n            name = channel._name\n            param_names = list(channel.channel_params.keys())\n            state_names = list(channel.channel_states.keys())\n            if channel_names is not None and name not in channel_names:\n                printable_nodes = printable_nodes.drop(name, axis=1)\n                printable_nodes = printable_nodes.drop(param_names, axis=1)\n                printable_nodes = printable_nodes.drop(state_names, axis=1)\n            else:\n                if not params:\n                    printable_nodes = printable_nodes.drop(param_names, axis=1)\n                if not states:\n                    printable_nodes = printable_nodes.drop(state_names, axis=1)\n\n        if not indices:\n            for name in [\"comp_index\", \"branch_index\", \"cell_index\"]:\n                printable_nodes = printable_nodes.drop(name, axis=1)\n\n        return printable_nodes\n\n    @abstractmethod\n    def init_conds(self, params: Dict):\n        \"\"\"Initialize coupling conductances.\n\n        Args:\n            params: Conductances and morphology parameters, not yet including\n                coupling conductances.\n        \"\"\"\n        raise NotImplementedError\n\n    def _append_channel_to_nodes(self, view: pd.DataFrame, channel: \"jx.Channel\"):\n        \"\"\"Adds channel nodes from constituents to `self.channel_nodes`.\"\"\"\n        name = channel._name\n\n        # Channel does not yet exist in the `jx.Module` at all.\n        if name not in [c._name for c in self.channels]:\n            self.channels.append(channel)\n            self.nodes[name] = False  # Previous columns do not have the new channel.\n\n        if channel.current_name not in self.membrane_current_names:\n            self.membrane_current_names.append(channel.current_name)\n\n        # Add a binary column that indicates if a channel is present.\n        self.nodes.loc[view.index.values, name] = True\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_params:\n            self.nodes.loc[view.index.values, key] = channel.channel_params[key]\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_states:\n            self.nodes.loc[view.index.values, key] = channel.channel_states[key]\n\n    def set(self, key: str, val: Union[float, jnp.ndarray]):\n        \"\"\"Set parameter of module (or its view) to a new value.\n\n        Note that this function can not be called within `jax.jit` or `jax.grad`.\n        Instead, it should be used set the parameters of the module **before** the\n        simulation. Use `.data_set()` to set parameters during `jax.jit` or\n        `jax.grad`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n        \"\"\"\n        # TODO(@michaeldeistler) should we allow `.set()` for synaptic parameters\n        # without using the `SynapseView`, purely for consistency with `make_trainable`?\n        view = (\n            self.edges\n            if key in self.synapse_param_names or key in self.synapse_state_names\n            else self.nodes\n        )\n        self._set(key, val, view, view)\n\n    def _set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        view: pd.DataFrame,\n        table_to_update: pd.DataFrame,\n    ):\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            table_to_update.loc[view.index.values, key] = val\n        else:\n            raise KeyError(\"Key not recognized.\")\n\n    def data_set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        param_state: Optional[List[Dict]],\n    ):\n        \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n            param_state: State of the setted parameters, internally used such that this\n                function does not modify global state.\n        \"\"\"\n        view = (\n            self.edges\n            if key in self.synapse_param_names or key in self.synapse_state_names\n            else self.nodes\n        )\n        return self._data_set(key, val, view, param_state=param_state)\n\n    def _data_set(\n        self,\n        key: str,\n        val: Tuple[float, jnp.ndarray],\n        view: pd.DataFrame,\n        param_state: Optional[List[Dict]] = None,\n    ):\n        # Note: `data_set` does not support arrays for `val`.\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            added_param_state = [\n                {\n                    \"indices\": np.atleast_2d(view.index.values),\n                    \"key\": key,\n                    \"val\": jnp.atleast_1d(jnp.asarray(val)),\n                }\n            ]\n            if param_state is not None:\n                param_state += added_param_state\n            else:\n                param_state = added_param_state\n        else:\n            raise KeyError(\"Key not recognized.\")\n        return param_state\n\n    def make_trainable(\n        self,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        \"\"\"Make a parameter trainable.\n\n        If a parameter is made trainable, it will be returned by `get_parameters()`\n        and should then be passed to `jx.integrate(..., params=params)`.\n\n        Args:\n            key: Name of the parameter to make trainable.\n            init_val: Initial value of the parameter. If `float`, the same value is\n                used for every created parameter. If `list`, the length of the list has\n                to match the number of created parameters. If `None`, the current\n                parameter value is used and if parameter sharing is performed that the\n                current parameter value is averaged over all shared parameters.\n            verbose: Whether to print the number of parameters that are added and the\n                total number of parameters.\n        \"\"\"\n        assert (\n            key not in self.synapse_param_names and key not in self.synapse_state_names\n        ), \"Parameters of synapses can only be made trainable via the `SynapseView`.\"\n        view = self.nodes\n        view = deepcopy(view.assign(controlled_by_param=0))\n        self._make_trainable(view, key, init_val, verbose=verbose)\n\n    def _make_trainable(\n        self,\n        view: pd.DataFrame,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        assert (\n            self.allow_make_trainable\n        ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            grouped_view = view.groupby(\"controlled_by_param\")\n            # Because of this `x.index.values` we cannot support `make_trainable()` on\n            # the module level for synapse parameters (but only for `SynapseView`).\n            inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))\n\n            # Sorted inds are only used to infer the correct starting values.\n            param_vals = jnp.asarray(\n                [view.loc[inds, key].to_numpy() for inds in inds_of_comps]\n            )\n        else:\n            raise KeyError(f\"Parameter {key} not recognized.\")\n\n        indices_per_param = jnp.stack(inds_of_comps)\n        self.indices_set_by_trainables.append(indices_per_param)\n\n        # Set the value which the trainable parameter should take.\n        num_created_parameters = len(indices_per_param)\n        if init_val is not None:\n            if isinstance(init_val, float):\n                new_params = jnp.asarray([init_val] * num_created_parameters)\n            elif isinstance(init_val, list):\n                assert (\n                    len(init_val) == num_created_parameters\n                ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n                new_params = jnp.asarray(init_val)\n            else:\n                raise ValueError(\n                    f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n                )\n        else:\n            new_params = jnp.mean(param_vals, axis=1)\n\n        self.trainable_params.append({key: new_params})\n        self.num_trainable_params += num_created_parameters\n        if verbose:\n            print(\n                f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}\"\n            )\n\n    def delete_trainables(self):\n        \"\"\"Removes all trainable parameters from the module.\"\"\"\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.num_trainable_params: int = 0\n\n    def add_to_group(self, group_name: str):\n        \"\"\"Add a view of the module to a group.\n\n        Groups can then be indexed. For example:\n        ```python\n        net.cell(0).add_to_group(\"excitatory\")\n        net.excitatory.set(\"radius\", 0.1)\n        ```\n\n        Args:\n            group_name: The name of the group.\n        \"\"\"\n        raise ValueError(\"`add_to_group()` makes no sense for an entire module.\")\n\n    def _add_to_group(self, group_name: str, view: pd.DataFrame):\n        if group_name in self.group_nodes:\n            view = pd.concat([self.group_nodes[group_name], view])\n        self.group_nodes[group_name] = view\n\n    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n        \"\"\"Get all trainable parameters.\n\n        The returned parameters should be passed to `jx.integrate(..., params=params).\n\n        Returns:\n            A list of all trainable parameters in the form of\n                [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n        \"\"\"\n        return self.trainable_params\n\n    def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]:\n        \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n        Runs `init_conds()` and return every parameter that is needed to solve the ODE.\n        This includes conductances, radiuses, lengths, axial_resistivities, but also\n        coupling conductances.\n\n        This is done by first obtaining the current value of every parameter (not only\n        the trainable ones) and then replacing the trainable ones with the value\n        in `trainable_params()`. This function is run within `jx.integrate()`.\n\n        pstate can be obtained by calling `params_to_pstate()`.\n        ```\n        params = module.get_parameters() # i.e. [0, 1, 2]\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        module.to_jax() # needed for call to module.jaxnodes\n        ```\n\n        Args:\n            pstate: The state of the trainable parameters. pstate takes the form\n                [{\"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]), \"val\": jnp.array([0.1, 0.2, 0.3])}, ...].\n\n        Returns:\n            A dictionary of all module parameters.\n        \"\"\"\n        params = {}\n        for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n            params[key] = self.jaxnodes[key]\n\n        for channel in self.channels:\n            for channel_params in channel.channel_params:\n                params[channel_params] = self.jaxnodes[channel_params]\n\n        for synapse_params in self.synapse_param_names:\n            params[synapse_params] = self.jaxedges[synapse_params]\n\n        # Override with those parameters set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in params:  # Only parameters, not initial states.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                params[key] = params[key].at[inds].set(set_param[:, None])\n\n        # Compute conductance params and append them.\n        cond_params = self.init_conds(params)\n        for key in cond_params:\n            params[key] = cond_params[key]\n\n        return params\n\n    def get_all_states(\n        self, pstate: List[Dict], all_params, delta_t: float\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n        Args:\n            pstate: The state of the trainable parameters.\n            all_params: All parameters of the module.\n            delta_t: The time step.\n\n        Returns:\n            A dictionary of all states of the module.\n        \"\"\"\n        # Join node and edge states into a single state dictionary.\n        states = {\"v\": self.jaxnodes[\"v\"]}\n        for channel in self.channels:\n            for channel_states in channel.channel_states:\n                states[channel_states] = self.jaxnodes[channel_states]\n        for synapse_states in self.synapse_state_names:\n            states[synapse_states] = self.jaxedges[synapse_states]\n\n        # Override with the initial states set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in states:  # Only initial states, not parameters.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                states[key] = states[key].at[inds].set(set_param[:, None])\n\n        # Add to the states the initial current through every channel.\n        states, _ = self._channel_currents(\n            states, delta_t, self.channels, self.nodes, all_params\n        )\n\n        # Add to the states the initial current through every synapse.\n        states, _ = self._synapse_currents(\n            states, self.synapses, all_params, delta_t, self.edges\n        )\n        return states\n\n    @property\n    def initialized(self):\n        \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n        return self.initialized_morph and self.initialized_syns\n\n    def initialize(self):\n        \"\"\"Initialize the module.\"\"\"\n        self.init_morph()\n        return self\n\n    def init_states(self):\n        \"\"\"Initialize all mechanisms in their steady state.\n\n        This considers the voltages and parameters of each compartment.\"\"\"\n        # Update states of the channels.\n        channel_nodes = self.nodes\n\n        for channel in self.channels:\n            name = channel._name\n            indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n            voltages = channel_nodes.loc[indices, \"v\"].to_numpy()\n\n            channel_param_names = list(channel.channel_params.keys())\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = channel_nodes[p][indices].to_numpy()\n\n            init_state = channel.init_state(voltages, channel_params)\n\n            # `init_state` might not return all channel states. Only the ones that are\n            # returned are updated here.\n            for key, val in init_state.items():\n                self.nodes.loc[indices, key] = val\n\n    def _init_morph_for_debugging(self):\n        \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n        This is important only for expert users who try to modify the solver for the\n        voltage equations. By default, this function is never run.\n\n        This is useful for debugging the solver because one can use\n        `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n        Here is the code snippet that can be used for debugging then (to be inserted in\n        `solver_voltage`):\n        ```python\n        from scipy.sparse import csc_matrix\n        from scipy.sparse.linalg import spsolve\n        from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n        elements, solve, num_entries, start_ind_for_branchpoints = (\n            build_voltage_matrix_elements(\n                uppers,\n                lowers,\n                diags,\n                solves,\n                branchpoint_conds_children[debug_states[\"child_inds\"]],\n                branchpoint_conds_parents[debug_states[\"par_inds\"]],\n                branchpoint_weights_children[debug_states[\"child_inds\"]],\n                branchpoint_weights_parents[debug_states[\"par_inds\"]],\n                branchpoint_diags,\n                branchpoint_solves,\n                debug_states[\"nseg\"],\n                nbranches,\n            )\n        )\n        sparse_matrix = csc_matrix(\n            (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n            shape=(num_entries, num_entries),\n        )\n        solution = spsolve(sparse_matrix, solve)\n        solution = solution[:start_ind_for_branchpoints]  # Delete branchpoint voltages.\n        solves = jnp.reshape(solution, (debug_states[\"nseg\"], nbranches))\n        return solves\n        ```\n        \"\"\"\n        # For scipy and jax.scipy.\n        row_and_col_inds = compute_morphology_indices(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.par_inds,\n            self.child_inds,\n            self.nseg,\n            self.total_nbranches,\n        )\n\n        num_elements = len(row_and_col_inds[\"row_inds\"])\n        data_inds, indices, indptr = convert_to_csc(\n            num_elements=num_elements,\n            row_ind=row_and_col_inds[\"row_inds\"],\n            col_ind=row_and_col_inds[\"col_inds\"],\n        )\n        self.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n        self.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n        self.debug_states[\"data_inds\"] = data_inds\n        self.debug_states[\"indices\"] = indices\n        self.debug_states[\"indptr\"] = indptr\n\n        self.debug_states[\"nseg\"] = self.nseg\n        self.debug_states[\"child_inds\"] = self.child_inds\n        self.debug_states[\"par_inds\"] = self.par_inds\n\n    def record(self, state: str = \"v\", verbose: bool = True):\n        \"\"\"Insert a recording into the compartment.\n\n        Args:\n            state: The name of the state to record.\n            verbose: Whether to print number of inserted recordings.\"\"\"\n        view = deepcopy(self.nodes)\n        view[\"state\"] = state\n        recording_view = view[[\"comp_index\", \"state\"]]\n        recording_view = recording_view.rename(columns={\"comp_index\": \"rec_index\"})\n        self._record(recording_view, verbose=verbose)\n\n    def _record(self, view: pd.DataFrame, verbose: bool = True):\n        self.recordings = pd.concat([self.recordings, view], ignore_index=True)\n        if verbose:\n            print(f\"Added {len(view)} recordings. See `.recordings` for details.\")\n\n    def delete_recordings(self):\n        \"\"\"Removes all recordings from the module.\"\"\"\n        self.recordings = pd.DataFrame().from_dict({})\n\n    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n        \"\"\"Insert a stimulus into the compartment.\n\n        current must be a 1d array or have batch dimension of size `(num_compartments, )`\n        or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n        This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n        it should only be used for static stimuli (i.e., stimuli that do not depend\n        on the data and that should not be learned). For stimuli that depend on data\n        (or that should be learned), please use `data_stimulate()`.\n\n        Args:\n            current: Current in `nA`.\n        \"\"\"\n        self._external_input(\"i\", current, self.nodes, verbose=verbose)\n\n    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n        \"\"\"Clamp a state to a given value across specified compartments.\n\n        Args:\n            state_name: The name of the state to clamp.\n            state_array (jnp.nd: Array of values to clamp the state to.\n            verbose : If True, prints details about the clamping.\n\n        This function sets external states for the compartments.\n        \"\"\"\n        if state_name not in self.nodes.columns:\n            raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n        self._external_input(state_name, state_array, self.nodes, verbose=verbose)\n\n    def _external_input(\n        self,\n        key: str,\n        values: Optional[jnp.ndarray],\n        view: pd.DataFrame,\n        verbose: bool = True,\n    ):\n        values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n        batch_size = values.shape[0]\n        is_multiple = len(view) == batch_size\n        values = values if is_multiple else jnp.repeat(values, len(view), axis=0)\n        assert batch_size in [1, len(view)], \"Number of comps and stimuli do not match.\"\n\n        if key in self.externals.keys():\n            self.externals[key] = jnp.concatenate([self.externals[key], values])\n            self.external_inds[key] = jnp.concatenate(\n                [self.external_inds[key], view.comp_index.to_numpy()]\n            )\n        else:\n            self.externals[key] = values\n            self.external_inds[key] = view.comp_index.to_numpy()\n\n        if verbose:\n            print(f\"Added {len(view)} external_states. See `.externals` for details.\")\n\n    def data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        \"\"\"Insert a stimulus into the module within jit (or grad).\n\n        Args:\n            current: Current in `nA`.\n            verbose: Whether or not to print the number of inserted stimuli. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        return self._data_stimulate(current, data_stimuli, self.nodes, verbose=verbose)\n\n    def _data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n        view: pd.DataFrame,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        current = current if current.ndim == 2 else jnp.expand_dims(current, axis=0)\n        batch_size = current.shape[0]\n        is_multiple = len(view) == batch_size\n        current = current if is_multiple else jnp.repeat(current, len(view), axis=0)\n        assert batch_size in [1, len(view)], \"Number of comps and stimuli do not match.\"\n\n        if data_stimuli is not None:\n            currents = data_stimuli[0]\n            inds = data_stimuli[1]\n        else:\n            currents = None\n            inds = pd.DataFrame().from_dict({})\n\n        # Same as in `.stimulate()`.\n        if currents is not None:\n            currents = jnp.concatenate([currents, current])\n        else:\n            currents = current\n        inds = pd.concat([inds, view])\n\n        if verbose:\n            print(f\"Added {len(view)} stimuli.\")\n\n        return (currents, inds)\n\n    def delete_stimuli(self):\n        \"\"\"Removes all stimuli from the module.\"\"\"\n        self.externals.pop(\"i\", None)\n        self.external_inds.pop(\"i\", None)\n\n    def insert(self, channel: Channel):\n        \"\"\"Insert a channel into the module.\n\n        Args:\n            channel: The channel to insert.\"\"\"\n        self._insert(channel, self.nodes)\n\n    def _insert(self, channel, view):\n        self._append_channel_to_nodes(view, channel)\n\n    def init_syns(self):\n        self.initialized_syns = True\n\n    def init_morph(self):\n        self.initialized_morph = True\n\n    def step(\n        self,\n        u: Dict[str, jnp.ndarray],\n        delta_t: float,\n        external_inds: Dict[str, jnp.ndarray],\n        externals: Dict[str, jnp.ndarray],\n        params: Dict[str, jnp.ndarray],\n        solver: str = \"bwd_euler\",\n        voltage_solver: str = \"jaxley.stone\",\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One step of solving the Ordinary Differential Equation.\n\n        This function is called inside of `integrate` and increments the state of the\n        module by one time step. Calls `_step_channels` and `_step_synapse` to update\n        the states of the channels and synapses using fwd_euler.\n\n        Args:\n            u: The state of the module. voltages = u[\"v\"]\n            delta_t: The time step.\n            external_inds: The indices of the external inputs.\n            externals: The external inputs.\n            params: The parameters of the module.\n            solver: The solver to use for the voltages. Either \"bwd_euler\" or \"fwd_euler\".\n            voltage_solver: The tridiagonal solver to used to diagonalize the\n                coefficient matrix of the ODE system. Either \"jaxley.thomas\",\n                \"jaxley.stone\", or \"jax.scipy.sparse\".\n\n        Returns:\n            The updated state of the module.\n        \"\"\"\n\n        # Extract the voltages\n        voltages = u[\"v\"]\n\n        # Extract the external inputs\n        has_current = \"i\" in externals.keys()\n        i_current = externals[\"i\"] if has_current else jnp.asarray([]).astype(\"float\")\n        i_inds = external_inds[\"i\"] if has_current else jnp.asarray([]).astype(\"int32\")\n        i_ext = self._get_external_input(\n            voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n        )\n\n        # Step of the channels.\n        u, (v_terms, const_terms) = self._step_channels(\n            u, delta_t, self.channels, self.nodes, params\n        )\n\n        # Step of the synapse.\n        u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n            u,\n            self.synapses,\n            params,\n            delta_t,\n            self.edges,\n        )\n\n        # Clamp for channels and synapses.\n        for key in externals.keys():\n            if key not in [\"i\", \"v\"]:\n                u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n        # Voltage steps.\n        cm = params[\"capacitance\"]  # Abbreviation.\n        if solver == \"bwd_euler\":\n            new_voltages = step_voltage_implicit(\n                voltages=voltages,\n                voltage_terms=(v_terms + syn_v_terms) / cm,\n                constant_terms=(const_terms + i_ext + syn_const_terms) / cm,\n                coupling_conds_upper=params[\"branch_uppers\"],\n                coupling_conds_lower=params[\"branch_lowers\"],\n                summed_coupling_conds=params[\"branch_diags\"],\n                branchpoint_conds_children=params[\"branchpoint_conds_children\"],\n                branchpoint_conds_parents=params[\"branchpoint_conds_parents\"],\n                branchpoint_weights_children=params[\"branchpoint_weights_children\"],\n                branchpoint_weights_parents=params[\"branchpoint_weights_parents\"],\n                par_inds=self.par_inds,\n                child_inds=self.child_inds,\n                nbranches=self.total_nbranches,\n                solver=voltage_solver,\n                delta_t=delta_t,\n                children_in_level=self.children_in_level,\n                parents_in_level=self.parents_in_level,\n                root_inds=self.root_inds,\n                branchpoint_group_inds=self.branchpoint_group_inds,\n                debug_states=self.debug_states,\n            )\n        else:\n            new_voltages = step_voltage_explicit(\n                voltages,\n                (v_terms + syn_v_terms) / cm,\n                (const_terms + i_ext + syn_const_terms) / cm,\n                coupling_conds_bwd=params[\"coupling_conds_bwd\"],\n                coupling_conds_fwd=params[\"coupling_conds_fwd\"],\n                branch_cond_fwd=params[\"branch_conds_fwd\"],\n                branch_cond_bwd=params[\"branch_conds_bwd\"],\n                nbranches=self.total_nbranches,\n                parents=self.comb_parents,\n                delta_t=delta_t,\n            )\n\n        u[\"v\"] = new_voltages.ravel(order=\"C\")\n\n        # Clamp for voltages.\n        if \"v\" in externals.keys():\n            u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n        return u\n\n    def _step_channels(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n        states = self._step_channels_state(\n            states, delta_t, channels, channel_nodes, params\n        )\n        states, current_terms = self._channel_currents(\n            states, delta_t, channels, channel_nodes, params\n        )\n        return states, current_terms\n\n    def _step_channels_state(\n        self,\n        states,\n        delta_t,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One integration step of the channels.\"\"\"\n        voltages = states[\"v\"]\n\n        query = lambda d, keys, idcs: dict(\n            zip(keys, (v[idcs] for v in map(d.get, keys)))\n        )  # get dict with subset of keys and values from d\n        # only loops over necessary keys, as opposed to looping over d.items()\n\n        # Update states of the channels.\n        indices = channel_nodes[\"comp_index\"].to_numpy()\n        for channel in channels:\n            channel_param_names = list(channel.channel_params)\n            channel_param_names += [\"radius\", \"length\", \"axial_resistivity\"]\n            channel_state_names = list(channel.channel_states)\n            channel_state_names += self.membrane_current_names\n            channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n            channel_params = query(params, channel_param_names, channel_indices)\n            channel_states = query(states, channel_state_names, channel_indices)\n\n            states_updated = channel.update_states(\n                channel_states, delta_t, voltages[channel_indices], channel_params\n            )\n            # Rebuild state. This has to be done within the loop over channels to allow\n            # multiple channels which modify the same state.\n            for key, val in states_updated.items():\n                states[key] = states[key].at[channel_indices].set(val)\n\n        return states\n\n    def _channel_currents(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the current through each channel.\n\n        This is also updates `state` because the `state` also contains the current.\n        \"\"\"\n        voltages = states[\"v\"]\n\n        # Compute current through channels.\n        voltage_terms = jnp.zeros_like(voltages)\n        constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n\n        current_states = {}\n        for name in self.membrane_current_names:\n            current_states[name] = jnp.zeros_like(voltages)\n\n        for channel in channels:\n            name = channel._name\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = params[p][indices]\n            channel_params[\"radius\"] = params[\"radius\"][indices]\n            channel_params[\"length\"] = params[\"length\"][indices]\n            channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n            channel_states = {}\n            for s in channel_state_names:\n                channel_states[s] = states[s][indices]\n\n            v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n            membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n                channel_states, v_and_perturbed, channel_params\n            )\n            voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n            constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n            voltage_terms = voltage_terms.at[indices].add(voltage_term)\n            constant_terms = constant_terms.at[indices].add(-constant_term)\n\n            # Save the current (for the unperturbed voltage) as a state that will\n            # also be passed to the state update.\n            current_states[channel.current_name] = (\n                current_states[channel.current_name]\n                .at[indices]\n                .add(membrane_currents[0])\n            )\n\n        # Copy the currents into the `state` dictionary such that they can be\n        # recorded and used by `Channel.update_states()`.\n        for name in self.membrane_current_names:\n            states[name] = current_states[name]\n\n        return states, (voltage_terms, constant_terms)\n\n    def _step_synapse(\n        self,\n        u: Dict[str, jnp.ndarray],\n        syn_channels: List[Channel],\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels.\n\n        `Network` overrides this method (because it actually has synapses), whereas\n        `Compartment`, `Branch`, and `Cell` do not override this.\n        \"\"\"\n        voltages = u[\"v\"]\n        return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n    def _synapse_currents(\n        self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        return states, (None, None)\n\n    @staticmethod\n    def _get_external_input(\n        voltages: jnp.ndarray,\n        i_inds: jnp.ndarray,\n        i_stim: jnp.ndarray,\n        radius: float,\n        length_single_compartment: float,\n    ) -> jnp.ndarray:\n        \"\"\"\n        Return external input to each compartment in uA / cm^2.\n\n        Args:\n            voltages: mV.\n            i_stim: nA.\n            radius: um.\n            length_single_compartment: um.\n        \"\"\"\n        zero_vec = jnp.zeros_like(voltages)\n        current = convert_point_process_to_distributed(\n            i_stim, radius[i_inds], length_single_compartment[i_inds]\n        )\n\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0,),\n            scatter_dims_to_operand_dims=(0,),\n        )\n        stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n        return stim_at_timestep\n\n    def vis(\n        self,\n        ax: Optional[Axes] = None,\n        col: str = \"k\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        morph_plot_kwargs: Dict = {},\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            ax: An axis into which to plot.\n            col: The color for all branches.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            morph_plot_kwargs: Keyword arguments passed to the plotting function.\n        \"\"\"\n        return self._vis(\n            dims=dims,\n            col=col,\n            ax=ax,\n            view=self.nodes,\n            type=type,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n    def _vis(\n        self,\n        ax: Axes,\n        col: str,\n        dims: Tuple[int],\n        view: pd.DataFrame,\n        type: str,\n        morph_plot_kwargs: Dict,\n    ) -> Axes:\n        branches_inds = view[\"branch_index\"].to_numpy()\n        coords = []\n        for branch_ind in branches_inds:\n            assert not np.any(\n                np.isnan(self.xyzr[branch_ind][:, dims])\n            ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n            coords.append(self.xyzr[branch_ind])\n\n        ax = plot_morph(\n            coords,\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=type,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        return ax\n\n    def _scatter(self, ax, col, dims, view, morph_plot_kwargs):\n        \"\"\"Scatter visualization (used only for compartments).\"\"\"\n        assert len(view) == 1, \"Scatter only deals with compartments.\"\n        branch_ind = view[\"branch_index\"].to_numpy().item()\n        comp_ind = view[\"comp_index\"].to_numpy().item()\n        assert not np.any(\n            np.isnan(self.xyzr[branch_ind][:, dims])\n        ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n        comp_fraction = loc_of_index(comp_ind, self.nseg)\n        coords = self.xyzr[branch_ind]\n        interpolated_xyz = interpolate_xyz(comp_fraction, coords)\n\n        ax = plot_morph(\n            np.asarray([[interpolated_xyz]]),\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=\"scatter\",\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        return ax\n\n    def compute_xyz(self):\n        \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n        This function should not be called if the morphology was read from an `.swc`\n        file. However, for morphologies that were constructed from scratch, this\n        function **must** be called before `.vis()`. The computed `xyz` coordinates\n        are only used for plotting.\n        \"\"\"\n        max_y_multiplier = 5.0\n        min_y_multiplier = 0.5\n\n        parents = self.comb_parents\n        num_children = _compute_num_children(parents)\n        index_of_child = _compute_index_of_child(parents)\n        levels = compute_levels(parents)\n\n        # Extract branch.\n        inds_branch = self.nodes.groupby(\"branch_index\")[\"comp_index\"].apply(list)\n        branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n        endpoints = []\n\n        # Different levels will get a different \"angle\" at which the children emerge from\n        # the parents. This angle is defined by the `y_offset_multiplier`. This value\n        # defines the range between y-location of the first and of the last child of a\n        # parent.\n        y_offset_multiplier = np.linspace(\n            max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n        )\n\n        for b in range(self.total_nbranches):\n            # For networks with mixed SWC and from-scatch neurons, only update those\n            # branches that do not have coordingates yet.\n            if np.any(np.isnan(self.xyzr[b])):\n                if parents[b] > -1:\n                    start_point = endpoints[parents[b]]\n                    num_children_of_parent = num_children[parents[b]]\n                    if num_children_of_parent > 1:\n                        y_offset = (\n                            ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                        ) * y_offset_multiplier[levels[b]]\n                    else:\n                        y_offset = 0.0\n                else:\n                    start_point = [0, 0, 0]\n                    y_offset = 0.0\n\n                len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n                end_point = [\n                    start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                    start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                    start_point[2],\n                ]\n                endpoints.append(end_point)\n\n                self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n            else:\n                # Dummy to keey the index `endpoints[parent[b]]` above working.\n                endpoints.append(np.zeros((2,)))\n\n    def move(\n        self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True\n    ):\n        \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            x: The amount to move in the x direction in um.\n            y: The amount to move in the y direction in um.\n            z: The amount to move in the z direction in um.\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        self._move(x, y, z, self.nodes, update_nodes)\n\n    def _move(self, x: float, y: float, z: float, view, update_nodes: bool):\n        # Need to cast to set because this will return one columnn per compartment,\n        # not one column per branch.\n        indizes = set(view[\"branch_index\"].to_numpy().tolist())\n        for i in indizes:\n            self.xyzr[i][:, 0] += x\n            self.xyzr[i][:, 1] += y\n            self.xyzr[i][:, 2] += z\n        if update_nodes:\n            self._update_nodes_with_xyz()\n\n    def move_to(\n        self,\n        x: Union[float, np.ndarray] = 0.0,\n        y: Union[float, np.ndarray] = 0.0,\n        z: Union[float, np.ndarray] = 0.0,\n        update_nodes: bool = True,\n    ):\n        \"\"\"Move cells or networks to a location (x, y, z).\n\n        If x, y, and z are floats, then the first compartment of the first branch\n        of the first cell is moved to that float coordinate, and everything else is\n        shifted by the difference between that compartment's previous coordinate and\n        the new float location.\n\n        If x, y, and z are arrays, then they must each have a length equal to the number\n        of cells being moved. Then the first compartment of the first branch of each\n        cell is moved to the specified location.\n\n        Args:\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        self._move_to(x, y, z, self.nodes, update_nodes)\n\n    def _move_to(\n        self,\n        x: Union[float, np.ndarray],\n        y: Union[float, np.ndarray],\n        z: Union[float, np.ndarray],\n        view: pd.DataFrame,\n        update_nodes: bool,\n    ):\n        # Test if any coordinate values are NaN which would greatly affect moving\n        if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n            raise ValueError(\n                \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n            )\n\n        # Get the indices of the cells and branches to move\n        cell_inds = list(view.cell_index.unique())\n        branch_inds = view.branch_index.unique()\n\n        if (\n            isinstance(x, np.ndarray)\n            and isinstance(y, np.ndarray)\n            and isinstance(z, np.ndarray)\n        ):\n            assert (\n                x.shape == y.shape == z.shape == (len(cell_inds),)\n            ), \"x, y, and z array shapes are not all equal to the number of cells to be moved.\"\n\n            # Split the branches by cell id\n            tup_indices = np.array([view.cell_index, view.branch_index])\n            view_cell_branch_inds = np.unique(tup_indices, axis=1)[0]\n            _, branch_split_inds = np.unique(view_cell_branch_inds, return_index=True)\n            branches_by_cell = np.split(\n                view.branch_index.unique(), branch_split_inds[1:]\n            )\n\n            # Calculate the amount to shift all of the branches of each cell\n            shift_amounts = (\n                np.array([x, y, z]).T - np.stack(self[cell_inds, 0].xyzr)[:, 0, :3]\n            )\n\n        else:\n            # Treat as if all branches belong to the same cell to be moved\n            branches_by_cell = [branch_inds]\n            # Calculate the amount to shift all branches by the 1st branch of 1st cell\n            shift_amounts = [np.array([x, y, z]) - self[cell_inds].xyzr[0][0, :3]]\n\n        # Move all of the branches\n        for i, branches in enumerate(branches_by_cell):\n            for b in branches:\n                self.xyzr[b][:, :3] += shift_amounts[i]\n\n        if update_nodes:\n            self._update_nodes_with_xyz()\n\n    def rotate(self, degrees: float, rotation_axis: str = \"xy\"):\n        \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            degrees: How many degrees to rotate the module by.\n            rotation_axis: Either of {`xy` | `xz` | `yz`}.\n        \"\"\"\n        self._rotate(degrees=degrees, rotation_axis=rotation_axis, view=self.nodes)\n\n    def _rotate(self, degrees: float, rotation_axis: str, view: pd.DataFrame):\n        degrees = degrees / 180 * np.pi\n        if rotation_axis == \"xy\":\n            dims = [0, 1]\n        elif rotation_axis == \"xz\":\n            dims = [0, 2]\n        elif rotation_axis == \"yz\":\n            dims = [1, 2]\n        else:\n            raise ValueError\n\n        rotation_matrix = np.asarray(\n            [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n        )\n        indizes = set(view[\"branch_index\"].to_numpy().tolist())\n        for i in indizes:\n            rot = np.dot(rotation_matrix, self.xyzr[i][:, dims].T).T\n            self.xyzr[i][:, dims] = rot\n\n    @property\n    def shape(self) -> Tuple[int]:\n        \"\"\"Returns the number of submodules contained in a module.\n\n        ```\n        network.shape = (num_cells, num_branches, num_compartments)\n        cell.shape = (num_branches, num_compartments)\n        branch.shape = (num_compartments,)\n        ```\"\"\"\n        mod_name = self.__class__.__name__.lower()\n        if \"comp\" in mod_name:\n            return (1,)\n        elif \"branch\" in mod_name:\n            return self[:].shape[1:]\n        return self[:].shape\n\n    def __getitem__(self, index):\n        return self._getitem(self, index)\n\n    def _getitem(\n        self,\n        module: Union[\"Module\", \"View\"],\n        index: Union[Tuple, int],\n        child_name: Optional[str] = None,\n    ) -> \"View\":\n        \"\"\"Return View which is created from indexing the module.\n\n        Args:\n            module: The module to be indexed. Will be a `Module` if `._getitem` is\n                called from `__getitem__` in a `Module` and will be a `View` if it was\n                called from `__getitem__` in a `View`.\n            index: The index (or indices) to index the module.\n            child_name: If passed, this will be the key that is used to index the\n                `module`, e.g. if it is the string `branch` then we will try to call\n                `module.xyz(index)`. If `None` then we try to infer automatically what\n                the childview should be, given the name of the `module`.\n\n        Returns:\n            An indexed `View`.\n        \"\"\"\n        if isinstance(index, tuple):\n            if len(index) > 1:\n                return childview(module, index[0], child_name)[index[1:]]\n            return childview(module, index[0], child_name)\n        return childview(module, index, child_name)\n\n    def __iter__(self):\n        for i in range(self.shape[0]):\n            yield self[i]\n\n    def _local_inds_to_global(\n        self, cell_inds: np.ndarray, branch_inds: np.ndarray, comp_inds: np.ndarray\n    ):\n        \"\"\"Given local inds of cell, branch, and comp, return the global comp index.\"\"\"\n        global_ind = (\n            self.cumsum_nbranches[cell_inds] + branch_inds\n        ) * self.nseg + comp_inds\n        return global_ind.astype(int)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized property","text":"

    Whether the Module is ready to be solved or not.

    "},{"location":"reference/modules/#jaxley.modules.base.Module.shape","title":"shape: Tuple[int] property","text":"

    Returns the number of submodules contained in a module.

    network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)","text":"

    Add a view of the module to a group.

    Groups can then be indexed. For example:

    net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n

    Parameters:

    Name Type Description Default group_name str

    The name of the group.

    required Source code in jaxley/modules/base.py
    def add_to_group(self, group_name: str):\n    \"\"\"Add a view of the module to a group.\n\n    Groups can then be indexed. For example:\n    ```python\n    net.cell(0).add_to_group(\"excitatory\")\n    net.excitatory.set(\"radius\", 0.1)\n    ```\n\n    Args:\n        group_name: The name of the group.\n    \"\"\"\n    raise ValueError(\"`add_to_group()` makes no sense for an entire module.\")\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)","text":"

    Clamp a state to a given value across specified compartments.

    Parameters:

    Name Type Description Default state_name str

    The name of the state to clamp.

    required state_array nd

    Array of values to clamp the state to.

    required verbose

    If True, prints details about the clamping.

    True

    This function sets external states for the compartments.

    Source code in jaxley/modules/base.py
    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n    \"\"\"Clamp a state to a given value across specified compartments.\n\n    Args:\n        state_name: The name of the state to clamp.\n        state_array (jnp.nd: Array of values to clamp the state to.\n        verbose : If True, prints details about the clamping.\n\n    This function sets external states for the compartments.\n    \"\"\"\n    if state_name not in self.nodes.columns:\n        raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n    self._external_input(state_name, state_array, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()","text":"

    Return xyz coordinates of every branch, based on the branch length.

    This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

    Source code in jaxley/modules/base.py
    def compute_xyz(self):\n    \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n    This function should not be called if the morphology was read from an `.swc`\n    file. However, for morphologies that were constructed from scratch, this\n    function **must** be called before `.vis()`. The computed `xyz` coordinates\n    are only used for plotting.\n    \"\"\"\n    max_y_multiplier = 5.0\n    min_y_multiplier = 0.5\n\n    parents = self.comb_parents\n    num_children = _compute_num_children(parents)\n    index_of_child = _compute_index_of_child(parents)\n    levels = compute_levels(parents)\n\n    # Extract branch.\n    inds_branch = self.nodes.groupby(\"branch_index\")[\"comp_index\"].apply(list)\n    branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n    endpoints = []\n\n    # Different levels will get a different \"angle\" at which the children emerge from\n    # the parents. This angle is defined by the `y_offset_multiplier`. This value\n    # defines the range between y-location of the first and of the last child of a\n    # parent.\n    y_offset_multiplier = np.linspace(\n        max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n    )\n\n    for b in range(self.total_nbranches):\n        # For networks with mixed SWC and from-scatch neurons, only update those\n        # branches that do not have coordingates yet.\n        if np.any(np.isnan(self.xyzr[b])):\n            if parents[b] > -1:\n                start_point = endpoints[parents[b]]\n                num_children_of_parent = num_children[parents[b]]\n                if num_children_of_parent > 1:\n                    y_offset = (\n                        ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                    ) * y_offset_multiplier[levels[b]]\n                else:\n                    y_offset = 0.0\n            else:\n                start_point = [0, 0, 0]\n                y_offset = 0.0\n\n            len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n            end_point = [\n                start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                start_point[2],\n            ]\n            endpoints.append(end_point)\n\n            self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n        else:\n            # Dummy to keey the index `endpoints[parent[b]]` above working.\n            endpoints.append(np.zeros((2,)))\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)","text":"

    Set parameter of module (or its view) to a new value within jit.

    Parameters:

    Name Type Description Default key str

    The name of the parameter to set.

    required val Union[float, ndarray]

    The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

    required param_state Optional[List[Dict]]

    State of the setted parameters, internally used such that this function does not modify global state.

    required Source code in jaxley/modules/base.py
    def data_set(\n    self,\n    key: str,\n    val: Union[float, jnp.ndarray],\n    param_state: Optional[List[Dict]],\n):\n    \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n        param_state: State of the setted parameters, internally used such that this\n            function does not modify global state.\n    \"\"\"\n    view = (\n        self.edges\n        if key in self.synapse_param_names or key in self.synapse_state_names\n        else self.nodes\n    )\n    return self._data_set(key, val, view, param_state=param_state)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)","text":"

    Insert a stimulus into the module within jit (or grad).

    Parameters:

    Name Type Description Default current ndarray

    Current in nA.

    required verbose bool

    Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.

    False Source code in jaxley/modules/base.py
    def data_stimulate(\n    self,\n    current: jnp.ndarray,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n    \"\"\"Insert a stimulus into the module within jit (or grad).\n\n    Args:\n        current: Current in `nA`.\n        verbose: Whether or not to print the number of inserted stimuli. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    return self._data_stimulate(current, data_stimuli, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()","text":"

    Removes all recordings from the module.

    Source code in jaxley/modules/base.py
    def delete_recordings(self):\n    \"\"\"Removes all recordings from the module.\"\"\"\n    self.recordings = pd.DataFrame().from_dict({})\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()","text":"

    Removes all stimuli from the module.

    Source code in jaxley/modules/base.py
    def delete_stimuli(self):\n    \"\"\"Removes all stimuli from the module.\"\"\"\n    self.externals.pop(\"i\", None)\n    self.external_inds.pop(\"i\", None)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()","text":"

    Removes all trainable parameters from the module.

    Source code in jaxley/modules/base.py
    def delete_trainables(self):\n    \"\"\"Removes all trainable parameters from the module.\"\"\"\n    self.indices_set_by_trainables: List[jnp.ndarray] = []\n    self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n    self.num_trainable_params: int = 0\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate)","text":"

    Return all parameters (and coupling conductances) needed to simulate.

    Runs init_conds() and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.

    This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params(). This function is run within jx.integrate().

    pstate can be obtained by calling params_to_pstate().

    params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n

    Parameters:

    Name Type Description Default pstate List[Dict]

    The state of the trainable parameters. pstate takes the form [{\u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

    required

    Returns:

    Type Description Dict[str, ndarray]

    A dictionary of all module parameters.

    Source code in jaxley/modules/base.py
    def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]:\n    \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n    Runs `init_conds()` and return every parameter that is needed to solve the ODE.\n    This includes conductances, radiuses, lengths, axial_resistivities, but also\n    coupling conductances.\n\n    This is done by first obtaining the current value of every parameter (not only\n    the trainable ones) and then replacing the trainable ones with the value\n    in `trainable_params()`. This function is run within `jx.integrate()`.\n\n    pstate can be obtained by calling `params_to_pstate()`.\n    ```\n    params = module.get_parameters() # i.e. [0, 1, 2]\n    pstate = params_to_pstate(params, module.indices_set_by_trainables)\n    module.to_jax() # needed for call to module.jaxnodes\n    ```\n\n    Args:\n        pstate: The state of the trainable parameters. pstate takes the form\n            [{\"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]), \"val\": jnp.array([0.1, 0.2, 0.3])}, ...].\n\n    Returns:\n        A dictionary of all module parameters.\n    \"\"\"\n    params = {}\n    for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n        params[key] = self.jaxnodes[key]\n\n    for channel in self.channels:\n        for channel_params in channel.channel_params:\n            params[channel_params] = self.jaxnodes[channel_params]\n\n    for synapse_params in self.synapse_param_names:\n        params[synapse_params] = self.jaxedges[synapse_params]\n\n    # Override with those parameters set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in params:  # Only parameters, not initial states.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            params[key] = params[key].at[inds].set(set_param[:, None])\n\n    # Compute conductance params and append them.\n    cond_params = self.init_conds(params)\n    for key in cond_params:\n        params[key] = cond_params[key]\n\n    return params\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)","text":"

    Get the full initial state of the module from jaxnodes and trainables.

    Parameters:

    Name Type Description Default pstate List[Dict]

    The state of the trainable parameters.

    required all_params

    All parameters of the module.

    required delta_t float

    The time step.

    required

    Returns:

    Type Description Dict[str, ndarray]

    A dictionary of all states of the module.

    Source code in jaxley/modules/base.py
    def get_all_states(\n    self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n    Args:\n        pstate: The state of the trainable parameters.\n        all_params: All parameters of the module.\n        delta_t: The time step.\n\n    Returns:\n        A dictionary of all states of the module.\n    \"\"\"\n    # Join node and edge states into a single state dictionary.\n    states = {\"v\": self.jaxnodes[\"v\"]}\n    for channel in self.channels:\n        for channel_states in channel.channel_states:\n            states[channel_states] = self.jaxnodes[channel_states]\n    for synapse_states in self.synapse_state_names:\n        states[synapse_states] = self.jaxedges[synapse_states]\n\n    # Override with the initial states set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in states:  # Only initial states, not parameters.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            states[key] = states[key].at[inds].set(set_param[:, None])\n\n    # Add to the states the initial current through every channel.\n    states, _ = self._channel_currents(\n        states, delta_t, self.channels, self.nodes, all_params\n    )\n\n    # Add to the states the initial current through every synapse.\n    states, _ = self._synapse_currents(\n        states, self.synapses, all_params, delta_t, self.edges\n    )\n    return states\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()","text":"

    Get all trainable parameters.

    The returned parameters should be passed to `jx.integrate(\u2026, params=params).

    Returns:

    Type Description List[Dict[str, ndarray]]

    A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

    Source code in jaxley/modules/base.py
    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n    \"\"\"Get all trainable parameters.\n\n    The returned parameters should be passed to `jx.integrate(..., params=params).\n\n    Returns:\n        A list of all trainable parameters in the form of\n            [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n    \"\"\"\n    return self.trainable_params\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.init_conds","title":"init_conds(params) abstractmethod","text":"

    Initialize coupling conductances.

    Parameters:

    Name Type Description Default params Dict

    Conductances and morphology parameters, not yet including coupling conductances.

    required Source code in jaxley/modules/base.py
    @abstractmethod\ndef init_conds(self, params: Dict):\n    \"\"\"Initialize coupling conductances.\n\n    Args:\n        params: Conductances and morphology parameters, not yet including\n            coupling conductances.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states()","text":"

    Initialize all mechanisms in their steady state.

    This considers the voltages and parameters of each compartment.

    Source code in jaxley/modules/base.py
    def init_states(self):\n    \"\"\"Initialize all mechanisms in their steady state.\n\n    This considers the voltages and parameters of each compartment.\"\"\"\n    # Update states of the channels.\n    channel_nodes = self.nodes\n\n    for channel in self.channels:\n        name = channel._name\n        indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n        voltages = channel_nodes.loc[indices, \"v\"].to_numpy()\n\n        channel_param_names = list(channel.channel_params.keys())\n        channel_params = {}\n        for p in channel_param_names:\n            channel_params[p] = channel_nodes[p][indices].to_numpy()\n\n        init_state = channel.init_state(voltages, channel_params)\n\n        # `init_state` might not return all channel states. Only the ones that are\n        # returned are updated here.\n        for key, val in init_state.items():\n            self.nodes.loc[indices, key] = val\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.initialize","title":"initialize()","text":"

    Initialize the module.

    Source code in jaxley/modules/base.py
    def initialize(self):\n    \"\"\"Initialize the module.\"\"\"\n    self.init_morph()\n    return self\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)","text":"

    Insert a channel into the module.

    Parameters:

    Name Type Description Default channel Channel

    The channel to insert.

    required Source code in jaxley/modules/base.py
    def insert(self, channel: Channel):\n    \"\"\"Insert a channel into the module.\n\n    Args:\n        channel: The channel to insert.\"\"\"\n    self._insert(channel, self.nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)","text":"

    Make a parameter trainable.

    If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(..., params=params).

    Parameters:

    Name Type Description Default key str

    Name of the parameter to make trainable.

    required init_val Optional[Union[float, list]]

    Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.

    None verbose bool

    Whether to print the number of parameters that are added and the total number of parameters.

    True Source code in jaxley/modules/base.py
    def make_trainable(\n    self,\n    key: str,\n    init_val: Optional[Union[float, list]] = None,\n    verbose: bool = True,\n):\n    \"\"\"Make a parameter trainable.\n\n    If a parameter is made trainable, it will be returned by `get_parameters()`\n    and should then be passed to `jx.integrate(..., params=params)`.\n\n    Args:\n        key: Name of the parameter to make trainable.\n        init_val: Initial value of the parameter. If `float`, the same value is\n            used for every created parameter. If `list`, the length of the list has\n            to match the number of created parameters. If `None`, the current\n            parameter value is used and if parameter sharing is performed that the\n            current parameter value is averaged over all shared parameters.\n        verbose: Whether to print the number of parameters that are added and the\n            total number of parameters.\n    \"\"\"\n    assert (\n        key not in self.synapse_param_names and key not in self.synapse_state_names\n    ), \"Parameters of synapses can only be made trainable via the `SynapseView`.\"\n    view = self.nodes\n    view = deepcopy(view.assign(controlled_by_param=0))\n    self._make_trainable(view, key, init_val, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=True)","text":"

    Move cells or networks by adding to their (x, y, z) coordinates.

    This function is used only for visualization. It does not affect the simulation.

    Parameters:

    Name Type Description Default x float

    The amount to move in the x direction in um.

    0.0 y float

    The amount to move in the y direction in um.

    0.0 z float

    The amount to move in the z direction in um.

    0.0 update_nodes bool

    Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

    True Source code in jaxley/modules/base.py
    def move(\n    self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True\n):\n    \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        x: The amount to move in the x direction in um.\n        y: The amount to move in the y direction in um.\n        z: The amount to move in the z direction in um.\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    self._move(x, y, z, self.nodes, update_nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=True)","text":"

    Move cells or networks to a location (x, y, z).

    If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.

    If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

    Parameters:

    Name Type Description Default update_nodes bool

    Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

    True Source code in jaxley/modules/base.py
    def move_to(\n    self,\n    x: Union[float, np.ndarray] = 0.0,\n    y: Union[float, np.ndarray] = 0.0,\n    z: Union[float, np.ndarray] = 0.0,\n    update_nodes: bool = True,\n):\n    \"\"\"Move cells or networks to a location (x, y, z).\n\n    If x, y, and z are floats, then the first compartment of the first branch\n    of the first cell is moved to that float coordinate, and everything else is\n    shifted by the difference between that compartment's previous coordinate and\n    the new float location.\n\n    If x, y, and z are arrays, then they must each have a length equal to the number\n    of cells being moved. Then the first compartment of the first branch of each\n    cell is moved to the specified location.\n\n    Args:\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    self._move_to(x, y, z, self.nodes, update_nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.record","title":"record(state='v', verbose=True)","text":"

    Insert a recording into the compartment.

    Parameters:

    Name Type Description Default state str

    The name of the state to record.

    'v' verbose bool

    Whether to print number of inserted recordings.

    True Source code in jaxley/modules/base.py
    def record(self, state: str = \"v\", verbose: bool = True):\n    \"\"\"Insert a recording into the compartment.\n\n    Args:\n        state: The name of the state to record.\n        verbose: Whether to print number of inserted recordings.\"\"\"\n    view = deepcopy(self.nodes)\n    view[\"state\"] = state\n    recording_view = view[[\"comp_index\", \"state\"]]\n    recording_view = recording_view.rename(columns={\"comp_index\": \"rec_index\"})\n    self._record(recording_view, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy')","text":"

    Rotate jaxley modules clockwise. Used only for visualization.

    This function is used only for visualization. It does not affect the simulation.

    Parameters:

    Name Type Description Default degrees float

    How many degrees to rotate the module by.

    required rotation_axis str

    Either of {xy | xz | yz}.

    'xy' Source code in jaxley/modules/base.py
    def rotate(self, degrees: float, rotation_axis: str = \"xy\"):\n    \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        degrees: How many degrees to rotate the module by.\n        rotation_axis: Either of {`xy` | `xz` | `yz`}.\n    \"\"\"\n    self._rotate(degrees=degrees, rotation_axis=rotation_axis, view=self.nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)","text":"

    Set parameter of module (or its view) to a new value.

    Note that this function can not be called within jax.jit or jax.grad. Instead, it should be used set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.

    Parameters:

    Name Type Description Default key str

    The name of the parameter to set.

    required val Union[float, ndarray]

    The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

    required Source code in jaxley/modules/base.py
    def set(self, key: str, val: Union[float, jnp.ndarray]):\n    \"\"\"Set parameter of module (or its view) to a new value.\n\n    Note that this function can not be called within `jax.jit` or `jax.grad`.\n    Instead, it should be used set the parameters of the module **before** the\n    simulation. Use `.data_set()` to set parameters during `jax.jit` or\n    `jax.grad`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n    \"\"\"\n    # TODO(@michaeldeistler) should we allow `.set()` for synaptic parameters\n    # without using the `SynapseView`, purely for consistency with `make_trainable`?\n    view = (\n        self.edges\n        if key in self.synapse_param_names or key in self.synapse_state_names\n        else self.nodes\n    )\n    self._set(key, val, view, view)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)","text":"

    Print detailed information about the Module or a view of it.

    Parameters:

    Name Type Description Default param_names Optional[Union[str, List[str]]]

    The names of the parameters to show. If None, all parameters are shown. NOT YET IMPLEMENTED.

    None indices bool

    Whether to show the indices of the compartments.

    True params bool

    Whether to show the parameters of the compartments.

    True states bool

    Whether to show the states of the compartments.

    True channel_names Optional[List[str]]

    The names of the channels to show. If None, all channels are shown.

    None

    Returns:

    Type Description DataFrame

    A pd.DataFrame with the requested information.

    Source code in jaxley/modules/base.py
    def show(\n    self,\n    param_names: Optional[Union[str, List[str]]] = None,  # TODO.\n    *,\n    indices: bool = True,\n    params: bool = True,\n    states: bool = True,\n    channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n    \"\"\"Print detailed information about the Module or a view of it.\n\n    Args:\n        param_names: The names of the parameters to show. If `None`, all parameters\n            are shown. NOT YET IMPLEMENTED.\n        indices: Whether to show the indices of the compartments.\n        params: Whether to show the parameters of the compartments.\n        states: Whether to show the states of the compartments.\n        channel_names: The names of the channels to show. If `None`, all channels are\n            shown.\n\n    Returns:\n        A `pd.DataFrame` with the requested information.\n    \"\"\"\n    return self._show(\n        self.nodes, param_names, indices, params, states, channel_names\n    )\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')","text":"

    One step of solving the Ordinary Differential Equation.

    This function is called inside of integrate and increments the state of the module by one time step. Calls _step_channels and _step_synapse to update the states of the channels and synapses using fwd_euler.

    Parameters:

    Name Type Description Default u Dict[str, ndarray]

    The state of the module. voltages = u[\u201cv\u201d]

    required delta_t float

    The time step.

    required external_inds Dict[str, ndarray]

    The indices of the external inputs.

    required externals Dict[str, ndarray]

    The external inputs.

    required params Dict[str, ndarray]

    The parameters of the module.

    required solver str

    The solver to use for the voltages. Either \u201cbwd_euler\u201d or \u201cfwd_euler\u201d.

    'bwd_euler' voltage_solver str

    The tridiagonal solver to used to diagonalize the coefficient matrix of the ODE system. Either \u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d, or \u201cjax.scipy.sparse\u201d.

    'jaxley.stone'

    Returns:

    Type Description Dict[str, ndarray]

    The updated state of the module.

    Source code in jaxley/modules/base.py
    def step(\n    self,\n    u: Dict[str, jnp.ndarray],\n    delta_t: float,\n    external_inds: Dict[str, jnp.ndarray],\n    externals: Dict[str, jnp.ndarray],\n    params: Dict[str, jnp.ndarray],\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"One step of solving the Ordinary Differential Equation.\n\n    This function is called inside of `integrate` and increments the state of the\n    module by one time step. Calls `_step_channels` and `_step_synapse` to update\n    the states of the channels and synapses using fwd_euler.\n\n    Args:\n        u: The state of the module. voltages = u[\"v\"]\n        delta_t: The time step.\n        external_inds: The indices of the external inputs.\n        externals: The external inputs.\n        params: The parameters of the module.\n        solver: The solver to use for the voltages. Either \"bwd_euler\" or \"fwd_euler\".\n        voltage_solver: The tridiagonal solver to used to diagonalize the\n            coefficient matrix of the ODE system. Either \"jaxley.thomas\",\n            \"jaxley.stone\", or \"jax.scipy.sparse\".\n\n    Returns:\n        The updated state of the module.\n    \"\"\"\n\n    # Extract the voltages\n    voltages = u[\"v\"]\n\n    # Extract the external inputs\n    has_current = \"i\" in externals.keys()\n    i_current = externals[\"i\"] if has_current else jnp.asarray([]).astype(\"float\")\n    i_inds = external_inds[\"i\"] if has_current else jnp.asarray([]).astype(\"int32\")\n    i_ext = self._get_external_input(\n        voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n    )\n\n    # Step of the channels.\n    u, (v_terms, const_terms) = self._step_channels(\n        u, delta_t, self.channels, self.nodes, params\n    )\n\n    # Step of the synapse.\n    u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n        u,\n        self.synapses,\n        params,\n        delta_t,\n        self.edges,\n    )\n\n    # Clamp for channels and synapses.\n    for key in externals.keys():\n        if key not in [\"i\", \"v\"]:\n            u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n    # Voltage steps.\n    cm = params[\"capacitance\"]  # Abbreviation.\n    if solver == \"bwd_euler\":\n        new_voltages = step_voltage_implicit(\n            voltages=voltages,\n            voltage_terms=(v_terms + syn_v_terms) / cm,\n            constant_terms=(const_terms + i_ext + syn_const_terms) / cm,\n            coupling_conds_upper=params[\"branch_uppers\"],\n            coupling_conds_lower=params[\"branch_lowers\"],\n            summed_coupling_conds=params[\"branch_diags\"],\n            branchpoint_conds_children=params[\"branchpoint_conds_children\"],\n            branchpoint_conds_parents=params[\"branchpoint_conds_parents\"],\n            branchpoint_weights_children=params[\"branchpoint_weights_children\"],\n            branchpoint_weights_parents=params[\"branchpoint_weights_parents\"],\n            par_inds=self.par_inds,\n            child_inds=self.child_inds,\n            nbranches=self.total_nbranches,\n            solver=voltage_solver,\n            delta_t=delta_t,\n            children_in_level=self.children_in_level,\n            parents_in_level=self.parents_in_level,\n            root_inds=self.root_inds,\n            branchpoint_group_inds=self.branchpoint_group_inds,\n            debug_states=self.debug_states,\n        )\n    else:\n        new_voltages = step_voltage_explicit(\n            voltages,\n            (v_terms + syn_v_terms) / cm,\n            (const_terms + i_ext + syn_const_terms) / cm,\n            coupling_conds_bwd=params[\"coupling_conds_bwd\"],\n            coupling_conds_fwd=params[\"coupling_conds_fwd\"],\n            branch_cond_fwd=params[\"branch_conds_fwd\"],\n            branch_cond_bwd=params[\"branch_conds_bwd\"],\n            nbranches=self.total_nbranches,\n            parents=self.comb_parents,\n            delta_t=delta_t,\n        )\n\n    u[\"v\"] = new_voltages.ravel(order=\"C\")\n\n    # Clamp for voltages.\n    if \"v\" in externals.keys():\n        u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n    return u\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)","text":"

    Insert a stimulus into the compartment.

    current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.

    This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

    Parameters:

    Name Type Description Default current Optional[ndarray]

    Current in nA.

    None Source code in jaxley/modules/base.py
    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n    \"\"\"Insert a stimulus into the compartment.\n\n    current must be a 1d array or have batch dimension of size `(num_compartments, )`\n    or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n    This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n    it should only be used for static stimuli (i.e., stimuli that do not depend\n    on the data and that should not be learned). For stimuli that depend on data\n    (or that should be learned), please use `data_stimulate()`.\n\n    Args:\n        current: Current in `nA`.\n    \"\"\"\n    self._external_input(\"i\", current, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()","text":"

    Move .nodes to .jaxnodes.

    Before the actual simulation is run (via jx.integrate), all parameters of the jx.Module are stored in .nodes (a pd.DataFrame). However, for simulation, these parameters have to be moved to be jnp.ndarrays such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax() copies the .nodes to .jaxnodes.

    Source code in jaxley/modules/base.py
    def to_jax(self):\n    \"\"\"Move `.nodes` to `.jaxnodes`.\n\n    Before the actual simulation is run (via `jx.integrate`), all parameters of\n    the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n    simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n    they can be processed on GPU/TPU and such that the simulation can be\n    differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n    \"\"\"\n    self.jaxnodes = {}\n    for key, value in self.nodes.to_dict(orient=\"list\").items():\n        inds = jnp.arange(len(value))\n        self.jaxnodes[key] = jnp.asarray(value)[inds]\n\n    # `jaxedges` contains only parameters (no indices).\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    self.jaxedges = {}\n    edges = self.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.synapses):\n        for key in synapse.synapse_params:\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n        for key in synapse.synapse_states:\n            self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, col='k', dims=(0, 1), type='line', morph_plot_kwargs={})","text":"

    Visualize the module.

    Parameters:

    Name Type Description Default ax Optional[Axes]

    An axis into which to plot.

    None col str

    The color for all branches.

    'k' dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) morph_plot_kwargs Dict

    Keyword arguments passed to the plotting function.

    {} Source code in jaxley/modules/base.py
    def vis(\n    self,\n    ax: Optional[Axes] = None,\n    col: str = \"k\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    morph_plot_kwargs: Dict = {},\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        ax: An axis into which to plot.\n        col: The color for all branches.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        morph_plot_kwargs: Keyword arguments passed to the plotting function.\n    \"\"\"\n    return self._vis(\n        dims=dims,\n        col=col,\n        ax=ax,\n        view=self.nodes,\n        type=type,\n        morph_plot_kwargs=morph_plot_kwargs,\n    )\n
    "},{"location":"reference/modules/#compartment","title":"Compartment","text":"

    Bases: Module

    Compartment class.

    This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.

    Source code in jaxley/modules/compartment.py
    class Compartment(Module):\n    \"\"\"Compartment class.\n\n    This class defines a single compartment that can be simulated by itself or\n    connected up into branches. It is the basic building block of a neuron model.\n    \"\"\"\n\n    compartment_params: Dict = {\n        \"length\": 10.0,  # um\n        \"radius\": 1.0,  # um\n        \"axial_resistivity\": 5_000.0,  # ohm cm\n        \"capacitance\": 1.0,  # uF/cm^2\n    }\n    compartment_states: Dict = {\"v\": -70.0}\n\n    def __init__(self):\n        super().__init__()\n\n        self.nseg = 1\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self.cumsum_nbranches = jnp.asarray([0, 1])\n\n        # Setting up the `nodes` for indexing.\n        self.nodes = pd.DataFrame(\n            dict(comp_index=[0], branch_index=[0], cell_index=[0])\n        )\n        self._append_params_and_states(self.compartment_params, self.compartment_states)\n\n        # Synapses.\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self.child_inds = np.asarray([]).astype(int)\n        self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n        self.par_inds = np.asarray([]).astype(int)\n        self.total_nbranchpoints = 0\n        self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n        self.children_in_level = []\n        self.parents_in_level = []\n        self.root_inds = jnp.asarray([0])\n\n        # Initialize the module.\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = True\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def init_conds(self, params):\n        return {\n            \"branchpoint_conds_children\": jnp.asarray([]),\n            \"branchpoint_conds_parents\": jnp.asarray([]),\n            \"branchpoint_weights_children\": jnp.asarray([]),\n            \"branchpoint_weights_parents\": jnp.asarray([]),\n            \"branch_uppers\": jnp.asarray([]),\n            \"branch_lowers\": jnp.asarray([]),\n            \"branch_diags\": jnp.asarray([0.0]),\n        }\n
    "},{"location":"reference/modules/#branch","title":"Branch","text":"

    Bases: Module

    Branch class.

    This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.

    Source code in jaxley/modules/branch.py
    class Branch(Module):\n    \"\"\"Branch class.\n\n    This class defines a single branch that can be simulated by itself or\n    connected to build a cell. A branch is linear segment of several compartments\n    and can be connected to no, one or more other branches at each end to build more\n    intricate cell morphologies.\n    \"\"\"\n\n    branch_params: Dict = {}\n    branch_states: Dict = {}\n\n    def __init__(\n        self,\n        compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n        nseg: Optional[int] = None,\n    ):\n        \"\"\"\n        Args:\n            compartments: A single compartment or a list of compartments that make up the\n                branch.\n            nseg: Number of segments to divide the branch into. If `compartments` is an\n                a single compartment, than the compartment is repeated `nseg` times to\n                create the branch.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(compartments, (Compartment, List)) or compartments is None\n        ), \"Only Compartment or List[Compartment] is allowed.\"\n        if isinstance(compartments, Compartment):\n            assert (\n                nseg is not None\n            ), \"If `compartments` is not a list then you have to set `nseg`.\"\n        compartments = Compartment() if compartments is None else compartments\n        nseg = 1 if nseg is None else nseg\n\n        if isinstance(compartments, Compartment):\n            compartment_list = [compartments] * nseg\n        else:\n            compartment_list = compartments\n\n        self.nseg = len(compartment_list)\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self.cumsum_nbranches = jnp.asarray([0, 1])\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n        self._append_params_and_states(self.branch_params, self.branch_states)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg).tolist()\n        self.nodes[\"branch_index\"] = [0] * self.nseg\n        self.nodes[\"cell_index\"] = [0] * self.nseg\n\n        # Channels.\n        self._gather_channels_from_constituents(compartment_list)\n\n        # Synapse indexing.\n        self.syn_edges = pd.DataFrame(\n            dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n        )\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self.child_inds = np.asarray([]).astype(int)\n        self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n        self.par_inds = np.asarray([]).astype(int)\n        self.total_nbranchpoints = 0\n        self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n        self.children_in_level = []\n        self.parents_in_level = []\n        self.root_inds = jnp.asarray([0])\n\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = False\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key in [\"comp\", \"loc\"]:\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            compview = CompartmentView(self, view)\n            return compview if key == \"comp\" else compview.loc\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, CompartmentView, [\"comp\", \"loc\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        conds = self.init_branch_conds(\n            params[\"axial_resistivity\"], params[\"radius\"], params[\"length\"], self.nseg\n        )\n        cond_params = {\n            \"branchpoint_conds_children\": jnp.asarray([]),\n            \"branchpoint_conds_parents\": jnp.asarray([]),\n            \"branchpoint_weights_children\": jnp.asarray([]),\n            \"branchpoint_weights_parents\": jnp.asarray([]),\n        }\n        cond_params[\"branch_lowers\"] = conds[0]\n        cond_params[\"branch_uppers\"] = conds[1]\n        cond_params[\"branch_diags\"] = conds[2]\n\n        return cond_params\n\n    @staticmethod\n    def init_branch_conds(\n        axial_resistivity: jnp.ndarray,\n        radiuses: jnp.ndarray,\n        lengths: jnp.ndarray,\n        nseg: int,\n    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\n\n        Args:\n            axial_resistivity: Axial resistivity of each compartment.\n            radiuses: Radius of each compartment.\n            lengths: Length of each compartment.\n            nseg: Number of compartments in the branch.\n\n        Returns:\n            Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.\n        \"\"\"\n\n        # Compute coupling conductance for segments within a branch.\n        # `radius`: um\n        # `r_a`: ohm cm\n        # `length_single_compartment`: um\n        # `coupling_conds`: S * um / cm / um^2 = S / cm / um\n        r1 = radiuses[:-1]\n        r2 = radiuses[1:]\n        r_a1 = axial_resistivity[:-1]\n        r_a2 = axial_resistivity[1:]\n        l1 = lengths[:-1]\n        l2 = lengths[1:]\n        coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)\n        coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)\n\n        # Compute the summed coupling conductances of each compartment.\n        summed_coupling_conds = jnp.zeros((nseg))\n        summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)\n        summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)\n        return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds\n\n    def __len__(self) -> int:\n        return self.nseg\n
    "},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, nseg=None)","text":"

    Parameters:

    Name Type Description Default compartments Optional[Union[Compartment, List[Compartment]]]

    A single compartment or a list of compartments that make up the branch.

    None nseg Optional[int]

    Number of segments to divide the branch into. If compartments is an a single compartment, than the compartment is repeated nseg times to create the branch.

    None Source code in jaxley/modules/branch.py
    def __init__(\n    self,\n    compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n    nseg: Optional[int] = None,\n):\n    \"\"\"\n    Args:\n        compartments: A single compartment or a list of compartments that make up the\n            branch.\n        nseg: Number of segments to divide the branch into. If `compartments` is an\n            a single compartment, than the compartment is repeated `nseg` times to\n            create the branch.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(compartments, (Compartment, List)) or compartments is None\n    ), \"Only Compartment or List[Compartment] is allowed.\"\n    if isinstance(compartments, Compartment):\n        assert (\n            nseg is not None\n        ), \"If `compartments` is not a list then you have to set `nseg`.\"\n    compartments = Compartment() if compartments is None else compartments\n    nseg = 1 if nseg is None else nseg\n\n    if isinstance(compartments, Compartment):\n        compartment_list = [compartments] * nseg\n    else:\n        compartment_list = compartments\n\n    self.nseg = len(compartment_list)\n    self.total_nbranches = 1\n    self.nbranches_per_cell = [1]\n    self.cumsum_nbranches = jnp.asarray([0, 1])\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n    self._append_params_and_states(self.branch_params, self.branch_states)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg).tolist()\n    self.nodes[\"branch_index\"] = [0] * self.nseg\n    self.nodes[\"cell_index\"] = [0] * self.nseg\n\n    # Channels.\n    self._gather_channels_from_constituents(compartment_list)\n\n    # Synapse indexing.\n    self.syn_edges = pd.DataFrame(\n        dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n    )\n    self.branch_edges = pd.DataFrame(\n        dict(parent_branch_index=[], child_branch_index=[])\n    )\n\n    # For morphology indexing.\n    self.child_inds = np.asarray([]).astype(int)\n    self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n    self.par_inds = np.asarray([]).astype(int)\n    self.total_nbranchpoints = 0\n    self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n    self.children_in_level = []\n    self.parents_in_level = []\n    self.root_inds = jnp.asarray([0])\n\n    self.initialize()\n    self.init_syns()\n    self.initialized_conds = False\n\n    # Coordinates.\n    self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
    "},{"location":"reference/modules/#jaxley.modules.branch.Branch.init_branch_conds","title":"init_branch_conds(axial_resistivity, radiuses, lengths, nseg) staticmethod","text":"

    Given an axial resisitivity, set the coupling conductances.

    Parameters:

    Name Type Description Default axial_resistivity ndarray

    Axial resistivity of each compartment.

    required radiuses ndarray

    Radius of each compartment.

    required lengths ndarray

    Length of each compartment.

    required nseg int

    Number of compartments in the branch.

    required

    Returns:

    Type Description Tuple[ndarray, ndarray, ndarray]

    Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.

    Source code in jaxley/modules/branch.py
    @staticmethod\ndef init_branch_conds(\n    axial_resistivity: jnp.ndarray,\n    radiuses: jnp.ndarray,\n    lengths: jnp.ndarray,\n    nseg: int,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\n\n    Args:\n        axial_resistivity: Axial resistivity of each compartment.\n        radiuses: Radius of each compartment.\n        lengths: Length of each compartment.\n        nseg: Number of compartments in the branch.\n\n    Returns:\n        Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.\n    \"\"\"\n\n    # Compute coupling conductance for segments within a branch.\n    # `radius`: um\n    # `r_a`: ohm cm\n    # `length_single_compartment`: um\n    # `coupling_conds`: S * um / cm / um^2 = S / cm / um\n    r1 = radiuses[:-1]\n    r2 = radiuses[1:]\n    r_a1 = axial_resistivity[:-1]\n    r_a2 = axial_resistivity[1:]\n    l1 = lengths[:-1]\n    l2 = lengths[1:]\n    coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)\n    coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)\n\n    # Compute the summed coupling conductances of each compartment.\n    summed_coupling_conds = jnp.zeros((nseg))\n    summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)\n    summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)\n    return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds\n
    "},{"location":"reference/modules/#cell","title":"Cell","text":"

    Bases: Module

    Cell class.

    This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.

    Source code in jaxley/modules/cell.py
    class Cell(Module):\n    \"\"\"Cell class.\n\n    This class defines a single cell that can be simulated by itself or\n    connected with synapses to build a network. A cell is made up of several branches\n    and supports intricate cell morphologies.\n    \"\"\"\n\n    cell_params: Dict = {}\n    cell_states: Dict = {}\n\n    def __init__(\n        self,\n        branches: Optional[Union[Branch, List[Branch]]] = None,\n        parents: Optional[List[int]] = None,\n        xyzr: Optional[List[np.ndarray]] = None,\n    ):\n        \"\"\"Initialize a cell.\n\n        Args:\n            branches: A single branch or a list of branches that make up the cell.\n                If a single branch is provided, then the branch is repeated `len(parents)`\n                times to create the cell.\n            parents: The parent branch index for each branch. The first branch has no\n                parent and is therefore set to -1.\n            xyzr: For every branch, the x, y, and z coordinates and the radius at the\n                traced coordinates. Note that this is the full tracing (from SWC), not\n                the stick representation coordinates.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(branches, (Branch, List)) or branches is None\n        ), \"Only Branch or List[Branch] is allowed.\"\n        if branches is not None:\n            assert (\n                parents is not None\n            ), \"If `branches` is not a list then you have to set `parents`.\"\n        if isinstance(branches, List):\n            assert len(parents) == len(\n                branches\n            ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n        branches = Branch() if branches is None else branches\n        parents = [-1] if parents is None else parents\n\n        if isinstance(branches, Branch):\n            branch_list = [branches for _ in range(len(parents))]\n        else:\n            branch_list = branches\n\n        if xyzr is not None:\n            assert len(xyzr) == len(parents)\n            self.xyzr = xyzr\n        else:\n            # For every branch (`len(parents)`), we have a start and end point (`2`) and\n            # a (x,y,z,r) coordinate for each of them (`4`).\n            # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n            # (potentially learned) length of every compartment, we only populate\n            # self.xyzr at `.vis()`.\n            self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n        self.nseg = branch_list[0].nseg\n        self.total_nbranches = len(branch_list)\n        self.nbranches_per_cell = [len(branch_list)]\n        self.comb_parents = jnp.asarray(parents)\n        self.comb_children = compute_children_indices(self.comb_parents)\n        self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n        self._append_params_and_states(self.cell_params, self.cell_states)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n        self.nodes[\"branch_index\"] = (\n            np.arange(self.nseg * self.total_nbranches) // self.nseg\n        ).tolist()\n        self.nodes[\"cell_index\"] = [0] * (self.nseg * self.total_nbranches)\n\n        # Channels.\n        self._gather_channels_from_constituents(branch_list)\n\n        # Synapse indexing.\n        self.syn_edges = pd.DataFrame(\n            dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n        )\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[1:],\n                child_branch_index=np.arange(1, self.total_nbranches),\n            )\n        )\n\n        # For morphology indexing.\n        par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n        self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n        self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n\n        # TODO: does order have to be preserved?\n        self.par_inds = np.unique(par_inds)\n        self.total_nbranchpoints = len(self.par_inds)\n        self.root_inds = jnp.asarray([0])\n\n        self.initialize()\n\n        self.init_syns()\n        self.initialized_conds = False\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key == \"branch\":\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return BranchView(self, view)\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, BranchView, [\"branch\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_morph(self):\n        \"\"\"Initialize morphology.\"\"\"\n\n        # For Jaxley custom implementation.\n        children_and_parents = compute_morphology_indices_in_levels(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.par_inds,\n            self.child_inds,\n        )\n        self.branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.nseg,\n            self.total_nbranches,\n        )\n        parents = self.comb_parents\n        children_inds = children_and_parents[\"children\"]\n        parents_inds = children_and_parents[\"parents\"]\n\n        levels = compute_levels(parents)\n        self.children_in_level = compute_children_in_level(levels, children_inds)\n        self.parents_in_level = compute_parents_in_level(\n            levels, self.par_inds, parents_inds\n        )\n\n        self.initialized_morph = True\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n        nbranches = self.total_nbranches\n        nseg = self.nseg\n\n        axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n        radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n        lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n        conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n            axial_resistivity, radiuses, lengths, self.nseg\n        )\n        coupling_conds_fwd = conds[0]\n        coupling_conds_bwd = conds[1]\n        summed_coupling_conds = conds[2]\n\n        # The conductance from the children to the branch point.\n        branchpoint_conds_children = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The conductance from the parents to the branch point.\n        branchpoint_conds_parents = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        # Weights with which the compartments influence their nearby node.\n        # The impact of the children on the branch point.\n        branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The impact of parents on the branch point.\n        branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        summed_coupling_conds = self.update_summed_coupling_conds(\n            summed_coupling_conds,\n            self.child_inds,\n            self.par_inds,\n            branchpoint_conds_children,\n            branchpoint_conds_parents,\n        )\n\n        cond_params = {\n            \"branch_uppers\": coupling_conds_bwd,\n            \"branch_lowers\": coupling_conds_fwd,\n            \"branch_diags\": summed_coupling_conds,\n            \"branchpoint_conds_children\": branchpoint_conds_children,\n            \"branchpoint_conds_parents\": branchpoint_conds_parents,\n            \"branchpoint_weights_children\": branchpoint_weights_children,\n            \"branchpoint_weights_parents\": branchpoint_weights_parents,\n        }\n        return cond_params\n\n    @staticmethod\n    def update_summed_coupling_conds(\n        summed_conds,\n        child_inds,\n        par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    ):\n        \"\"\"Perform updates on the diagonal based on conductances of the branchpoints.\n\n        Args:\n            summed_conds: shape [num_branches, nseg]\n            child_inds: shape [num_branches - 1]\n            conds_fwd: shape [num_branches - 1]\n            conds_bwd: shape [num_branches - 1]\n            parents: shape [num_branches]\n\n        Returns:\n            Updated `summed_coupling_conds`.\n        \"\"\"\n        summed_conds = summed_conds.at[child_inds, 0].add(branchpoint_conds_children)\n        summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents)\n        return summed_conds\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)","text":"

    Initialize a cell.

    Parameters:

    Name Type Description Default branches Optional[Union[Branch, List[Branch]]]

    A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents) times to create the cell.

    None parents Optional[List[int]]

    The parent branch index for each branch. The first branch has no parent and is therefore set to -1.

    None xyzr Optional[List[ndarray]]

    For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.

    None Source code in jaxley/modules/cell.py
    def __init__(\n    self,\n    branches: Optional[Union[Branch, List[Branch]]] = None,\n    parents: Optional[List[int]] = None,\n    xyzr: Optional[List[np.ndarray]] = None,\n):\n    \"\"\"Initialize a cell.\n\n    Args:\n        branches: A single branch or a list of branches that make up the cell.\n            If a single branch is provided, then the branch is repeated `len(parents)`\n            times to create the cell.\n        parents: The parent branch index for each branch. The first branch has no\n            parent and is therefore set to -1.\n        xyzr: For every branch, the x, y, and z coordinates and the radius at the\n            traced coordinates. Note that this is the full tracing (from SWC), not\n            the stick representation coordinates.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(branches, (Branch, List)) or branches is None\n    ), \"Only Branch or List[Branch] is allowed.\"\n    if branches is not None:\n        assert (\n            parents is not None\n        ), \"If `branches` is not a list then you have to set `parents`.\"\n    if isinstance(branches, List):\n        assert len(parents) == len(\n            branches\n        ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n    branches = Branch() if branches is None else branches\n    parents = [-1] if parents is None else parents\n\n    if isinstance(branches, Branch):\n        branch_list = [branches for _ in range(len(parents))]\n    else:\n        branch_list = branches\n\n    if xyzr is not None:\n        assert len(xyzr) == len(parents)\n        self.xyzr = xyzr\n    else:\n        # For every branch (`len(parents)`), we have a start and end point (`2`) and\n        # a (x,y,z,r) coordinate for each of them (`4`).\n        # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n        # (potentially learned) length of every compartment, we only populate\n        # self.xyzr at `.vis()`.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n    self.nseg = branch_list[0].nseg\n    self.total_nbranches = len(branch_list)\n    self.nbranches_per_cell = [len(branch_list)]\n    self.comb_parents = jnp.asarray(parents)\n    self.comb_children = compute_children_indices(self.comb_parents)\n    self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n    self._append_params_and_states(self.cell_params, self.cell_states)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n    self.nodes[\"branch_index\"] = (\n        np.arange(self.nseg * self.total_nbranches) // self.nseg\n    ).tolist()\n    self.nodes[\"cell_index\"] = [0] * (self.nseg * self.total_nbranches)\n\n    # Channels.\n    self._gather_channels_from_constituents(branch_list)\n\n    # Synapse indexing.\n    self.syn_edges = pd.DataFrame(\n        dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n    )\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[1:],\n            child_branch_index=np.arange(1, self.total_nbranches),\n        )\n    )\n\n    # For morphology indexing.\n    par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n    self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n    self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n\n    # TODO: does order have to be preserved?\n    self.par_inds = np.unique(par_inds)\n    self.total_nbranchpoints = len(self.par_inds)\n    self.root_inds = jnp.asarray([0])\n\n    self.initialize()\n\n    self.init_syns()\n    self.initialized_conds = False\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.init_conds","title":"init_conds(params)","text":"

    Given an axial resisitivity, set the coupling conductances.

    Source code in jaxley/modules/cell.py
    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n    nbranches = self.total_nbranches\n    nseg = self.nseg\n\n    axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n    radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n    lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n    conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n        axial_resistivity, radiuses, lengths, self.nseg\n    )\n    coupling_conds_fwd = conds[0]\n    coupling_conds_bwd = conds[1]\n    summed_coupling_conds = conds[2]\n\n    # The conductance from the children to the branch point.\n    branchpoint_conds_children = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The conductance from the parents to the branch point.\n    branchpoint_conds_parents = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    # Weights with which the compartments influence their nearby node.\n    # The impact of the children on the branch point.\n    branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The impact of parents on the branch point.\n    branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    summed_coupling_conds = self.update_summed_coupling_conds(\n        summed_coupling_conds,\n        self.child_inds,\n        self.par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    )\n\n    cond_params = {\n        \"branch_uppers\": coupling_conds_bwd,\n        \"branch_lowers\": coupling_conds_fwd,\n        \"branch_diags\": summed_coupling_conds,\n        \"branchpoint_conds_children\": branchpoint_conds_children,\n        \"branchpoint_conds_parents\": branchpoint_conds_parents,\n        \"branchpoint_weights_children\": branchpoint_weights_children,\n        \"branchpoint_weights_parents\": branchpoint_weights_parents,\n    }\n    return cond_params\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.init_morph","title":"init_morph()","text":"

    Initialize morphology.

    Source code in jaxley/modules/cell.py
    def init_morph(self):\n    \"\"\"Initialize morphology.\"\"\"\n\n    # For Jaxley custom implementation.\n    children_and_parents = compute_morphology_indices_in_levels(\n        len(self.par_inds),\n        self.child_belongs_to_branchpoint,\n        self.par_inds,\n        self.child_inds,\n    )\n    self.branchpoint_group_inds = build_branchpoint_group_inds(\n        len(self.par_inds),\n        self.child_belongs_to_branchpoint,\n        self.nseg,\n        self.total_nbranches,\n    )\n    parents = self.comb_parents\n    children_inds = children_and_parents[\"children\"]\n    parents_inds = children_and_parents[\"parents\"]\n\n    levels = compute_levels(parents)\n    self.children_in_level = compute_children_in_level(levels, children_inds)\n    self.parents_in_level = compute_parents_in_level(\n        levels, self.par_inds, parents_inds\n    )\n\n    self.initialized_morph = True\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.update_summed_coupling_conds","title":"update_summed_coupling_conds(summed_conds, child_inds, par_inds, branchpoint_conds_children, branchpoint_conds_parents) staticmethod","text":"

    Perform updates on the diagonal based on conductances of the branchpoints.

    Parameters:

    Name Type Description Default summed_conds

    shape [num_branches, nseg]

    required child_inds

    shape [num_branches - 1]

    required conds_fwd

    shape [num_branches - 1]

    required conds_bwd

    shape [num_branches - 1]

    required parents

    shape [num_branches]

    required

    Returns:

    Type Description

    Updated summed_coupling_conds.

    Source code in jaxley/modules/cell.py
    @staticmethod\ndef update_summed_coupling_conds(\n    summed_conds,\n    child_inds,\n    par_inds,\n    branchpoint_conds_children,\n    branchpoint_conds_parents,\n):\n    \"\"\"Perform updates on the diagonal based on conductances of the branchpoints.\n\n    Args:\n        summed_conds: shape [num_branches, nseg]\n        child_inds: shape [num_branches - 1]\n        conds_fwd: shape [num_branches - 1]\n        conds_bwd: shape [num_branches - 1]\n        parents: shape [num_branches]\n\n    Returns:\n        Updated `summed_coupling_conds`.\n    \"\"\"\n    summed_conds = summed_conds.at[child_inds, 0].add(branchpoint_conds_children)\n    summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents)\n    return summed_conds\n
    "},{"location":"reference/modules/#network","title":"Network","text":"

    Bases: Module

    Network class.

    This class defines a network of cells that can be connected with synapses.

    Source code in jaxley/modules/network.py
    class Network(Module):\n    \"\"\"Network class.\n\n    This class defines a network of cells that can be connected with synapses.\n    \"\"\"\n\n    network_params: Dict = {}\n    network_states: Dict = {}\n\n    def __init__(\n        self,\n        cells: List[Cell],\n    ):\n        \"\"\"Initialize network of cells and synapses.\n\n        Args:\n            cells: A list of cells that make up the network.\n        \"\"\"\n        super().__init__()\n        for cell in cells:\n            self.xyzr += deepcopy(cell.xyzr)\n\n        self.cells = cells\n        self.nseg = cells[0].nseg\n        self._append_params_and_states(self.network_params, self.network_states)\n\n        self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]\n        self.nbranchpoints_per_cell = [cell.total_nbranchpoints for cell in self.cells]\n        self.total_nbranches = sum(self.nbranches_per_cell)\n        self.cumsum_nbranches = jnp.cumsum(jnp.asarray([0] + self.nbranches_per_cell))\n        self.cumsum_nbranchpoints = jnp.cumsum(\n            jnp.asarray([0] + self.nbranchpoints_per_cell)\n        )\n\n        self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n        self.nodes[\"branch_index\"] = (\n            np.arange(self.nseg * self.total_nbranches) // self.nseg\n        ).tolist()\n        self.nodes[\"cell_index\"] = list(\n            itertools.chain(\n                *[[i] * (self.nseg * b) for i, b in enumerate(self.nbranches_per_cell)]\n            )\n        )\n\n        parents = [cell.comb_parents for cell in self.cells]\n        self.comb_parents = jnp.concatenate(\n            [p.at[1:].add(self.cumsum_nbranches[i]) for i, p in enumerate(parents)]\n        )\n\n        # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n        # branch, apart from those branches which do not have a parent (i.e.\n        # -1 in parents). For every branch, tracks the global index of that branch\n        # (`child_branch_index`) and the global index of its parent\n        # (`parent_branch_index`).\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[self.comb_parents != -1],\n                child_branch_index=np.where(self.comb_parents != -1)[0],\n            )\n        )\n\n        # For morphology indexing.\n        par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n        self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n        self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n        self.par_inds = np.unique(par_inds)  # TODO: does order have to be preserved?\n        self.total_nbranchpoints = len(self.par_inds)\n        self.root_inds = self.cumsum_nbranches[:-1]\n\n        # Channels.\n        self._gather_channels_from_constituents(cells)\n\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = False\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key == \"cell\":\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return CellView(self, view)\n        elif key in self.synapse_names:\n            type_index = self.synapse_names.index(key)\n            return SynapseView(self, self.edges, key, self.synapses[type_index])\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, CellView, [\"cell\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_morph(self):\n        self.branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.nseg,\n            self.total_nbranches,\n        )\n        self.children_in_level = merge_cells(\n            self.cumsum_nbranches,\n            self.cumsum_nbranchpoints,\n            [cell.children_in_level for cell in self.cells],\n            exclude_first=False,\n        )\n        self.parents_in_level = merge_cells(\n            self.cumsum_nbranches,\n            self.cumsum_nbranchpoints,\n            [cell.parents_in_level for cell in self.cells],\n            exclude_first=False,\n        )\n        self.initialized_morph = True\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n        nbranches = self.total_nbranches\n        nseg = self.nseg\n        parents = self.comb_parents\n\n        axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n        radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n        lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n        conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n            axial_resistivity, radiuses, lengths, self.nseg\n        )\n        coupling_conds_fwd = conds[0]\n        coupling_conds_bwd = conds[1]\n        summed_coupling_conds = conds[2]\n\n        # The conductance from the children to the branch point.\n        branchpoint_conds_children = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The conductance from the parents to the branch point.\n        branchpoint_conds_parents = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        # Weights with which the compartments influence their nearby node.\n        # The impact of the children on the branch point.\n        branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The impact of parents on the branch point.\n        branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        summed_coupling_conds = Cell.update_summed_coupling_conds(\n            summed_coupling_conds,\n            self.child_inds,\n            self.par_inds,\n            branchpoint_conds_children,\n            branchpoint_conds_parents,\n        )\n\n        cond_params = {\n            \"branch_uppers\": coupling_conds_bwd,\n            \"branch_lowers\": coupling_conds_fwd,\n            \"branch_diags\": summed_coupling_conds,\n            \"branchpoint_conds_children\": branchpoint_conds_children,\n            \"branchpoint_conds_parents\": branchpoint_conds_parents,\n            \"branchpoint_weights_children\": branchpoint_weights_children,\n            \"branchpoint_weights_parents\": branchpoint_weights_parents,\n        }\n        return cond_params\n\n    def init_syns(self):\n        \"\"\"Initialize synapses.\"\"\"\n        self.synapses = []\n\n        # TODO(@michaeldeistler): should we also track this for channels?\n        self.synapse_names = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n\n        self.initialized_syns = True\n\n    def _step_synapse(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n        states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n        states, current_terms = self._synapse_currents(\n            states, syn_channels, params, delta_t, edges\n        )\n        return states, current_terms\n\n    def _step_synapse_state(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Dict:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"global_pre_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"global_post_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # State updates.\n            states_updated = synapse_type.update_states(\n                synapse_states,\n                delta_t,\n                voltages[pre_inds],\n                voltages[post_inds],\n                synapse_params,\n            )\n\n            # Rebuild state.\n            for key, val in states_updated.items():\n                states[key] = val\n\n        return states\n\n    def _synapse_currents(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"global_pre_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"global_post_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        syn_voltage_terms = jnp.zeros_like(voltages)\n        syn_constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            # Get pre and post indexes of the current synapse type.\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # Compute slope and offset of the current through every synapse.\n            pre_v_and_perturbed = jnp.stack(\n                [voltages[pre_inds], voltages[pre_inds] + diff]\n            )\n            post_v_and_perturbed = jnp.stack(\n                [voltages[post_inds], voltages[post_inds] + diff]\n            )\n            synapse_currents = vmap(\n                synapse_type.compute_current, in_axes=(None, 0, 0, None)\n            )(\n                synapse_states,\n                pre_v_and_perturbed,\n                post_v_and_perturbed,\n                synapse_params,\n            )\n            synapse_currents_dist = convert_point_process_to_distributed(\n                synapse_currents,\n                params[\"radius\"][post_inds],\n                params[\"length\"][post_inds],\n            )\n\n            # Split into voltage and constant terms.\n            voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n            constant_term = (\n                synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n            )\n\n            # Gather slope and offset for every postsynaptic compartment.\n            gathered_syn_currents = gather_synapes(\n                len(voltages),\n                post_inds,\n                voltage_term,\n                constant_term,\n            )\n            syn_voltage_terms += gathered_syn_currents[0]\n            syn_constant_terms -= gathered_syn_currents[1]\n\n            # Add the synaptic currents through every compartment as state.\n            # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n            # compartments in the network.\n            # `[0]` because we only use the non-perturbed voltage.\n            states[f\"{synapse_type._name}_current\"] = synapse_currents[0]\n\n        return states, (syn_voltage_terms, syn_constant_terms)\n\n    def vis(\n        self,\n        detail: str = \"full\",\n        ax: Optional[Axes] = None,\n        col: str = \"k\",\n        synapse_col: str = \"b\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        layers: Optional[List] = None,\n        morph_plot_kwargs: Dict = {},\n        synapse_plot_kwargs: Dict = {},\n        synapse_scatter_kwargs: Dict = {},\n        networkx_options: Dict = {},\n        layer_kwargs: Dict = {},\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            detail: Either of [point, full]. `point` visualizes every neuron in the\n                network as a dot (and it uses `networkx` to obtain cell positions).\n                `full` plots the full morphology of every neuron. It requires that\n                `compute_xyz()` has been run and allows for indivual neurons to be\n                moved with `.move()`.\n            col: The color in which cells are plotted. Only takes effect if\n                `detail='full'`.\n            type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n            synapse_col: The color in which synapses are plotted. Only takes effect if\n                `detail='full'`.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            layers: Allows to plot the network in layers. Should provide the number of\n                neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n                neurons, 10 hidden layer neurons, and 1 output neuron.\n            morph_plot_kwargs: Keyword arguments passed to the plotting function for\n                cell morphologies. Only takes effect for `detail='full'`.\n            synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n                syanpses. Only takes effect for `detail='full'`.\n            synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n                for the end point of synapses. Only takes effect for `detail='full'`.\n            networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n                `detail='point'`.\n            layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n                Can have the following entries: `within_layer_offset` (float),\n                `between_layer_offset` (float), `vertical_layers` (bool).\n        \"\"\"\n        if detail == \"point\":\n            graph = self._build_graph(layers)\n\n            if layers is not None:\n                pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n                nx.draw(graph, pos, with_labels=True, **networkx_options)\n            else:\n                nx.draw(graph, with_labels=True, **networkx_options)\n        elif detail == \"full\":\n            if layers is not None:\n                # Assemble cells in the network into layers.\n                global_counter = 0\n                layers_config = {\n                    \"within_layer_offset\": 500.0,\n                    \"between_layer_offset\": 1500.0,\n                    \"vertical_layers\": False,\n                }\n                layers_config.update(layer_kwargs)\n                for layer_ind, num_in_layer in enumerate(layers):\n                    for ind_within_layer in range(num_in_layer):\n                        if layers_config[\"vertical_layers\"]:\n                            x_offset = (\n                                ind_within_layer - (num_in_layer - 1) / 2\n                            ) * layers_config[\"within_layer_offset\"]\n                            y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n                                \"between_layer_offset\"\n                            ]\n                        else:\n                            x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n                            y_offset = (\n                                ind_within_layer - (num_in_layer - 1) / 2\n                            ) * layers_config[\"within_layer_offset\"]\n\n                        self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n                        global_counter += 1\n            ax = self._vis(\n                dims=dims,\n                col=col,\n                ax=ax,\n                type=type,\n                view=self.nodes,\n                morph_plot_kwargs=morph_plot_kwargs,\n            )\n\n            pre_locs = self.edges[\"pre_locs\"].to_numpy()\n            post_locs = self.edges[\"post_locs\"].to_numpy()\n            pre_branch = self.edges[\"global_pre_branch_index\"].to_numpy()\n            post_branch = self.edges[\"global_post_branch_index\"].to_numpy()\n\n            dims_np = np.asarray(dims)\n\n            for pre_loc, post_loc, pre_b, post_b in zip(\n                pre_locs, post_locs, pre_branch, post_branch\n            ):\n                pre_coord = self.xyzr[pre_b]\n                if len(pre_coord) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(pre_coord) - 1) * pre_loc)\n                    pre_coord = pre_coord[middle_ind]\n\n                post_coord = self.xyzr[post_b]\n                if len(post_coord) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    post_coord = (\n                        post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n                    )\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(post_coord) - 1) * post_loc)\n                    post_coord = post_coord[middle_ind]\n\n                coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n                ax.plot(\n                    coords[0],\n                    coords[1],\n                    c=synapse_col,\n                    **synapse_plot_kwargs,\n                )\n                ax.scatter(\n                    post_coord[dims_np[0]],\n                    post_coord[dims_np[1]],\n                    c=synapse_col,\n                    **synapse_scatter_kwargs,\n                )\n        else:\n            raise ValueError(\"detail must be in {full, point}.\")\n\n        return ax\n\n    def _build_graph(self, layers: Optional[List] = None, **options):\n        graph = nx.DiGraph()\n\n        def build_extents(*subset_sizes):\n            return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))\n\n        if layers is not None:\n            extents = build_extents(*layers)\n            layers = [range(start, end) for start, end in extents]\n            for i, layer in enumerate(layers):\n                graph.add_nodes_from(layer, layer=i)\n        else:\n            graph.add_nodes_from(range(len(self.cells)))\n\n        pre_cell = self.edges[\"pre_cell_index\"].to_numpy()\n        post_cell = self.edges[\"post_cell_index\"].to_numpy()\n\n        inds = np.stack([pre_cell, post_cell]).T\n        graph.add_edges_from(inds)\n\n        return graph\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)","text":"

    Initialize network of cells and synapses.

    Parameters:

    Name Type Description Default cells List[Cell]

    A list of cells that make up the network.

    required Source code in jaxley/modules/network.py
    def __init__(\n    self,\n    cells: List[Cell],\n):\n    \"\"\"Initialize network of cells and synapses.\n\n    Args:\n        cells: A list of cells that make up the network.\n    \"\"\"\n    super().__init__()\n    for cell in cells:\n        self.xyzr += deepcopy(cell.xyzr)\n\n    self.cells = cells\n    self.nseg = cells[0].nseg\n    self._append_params_and_states(self.network_params, self.network_states)\n\n    self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]\n    self.nbranchpoints_per_cell = [cell.total_nbranchpoints for cell in self.cells]\n    self.total_nbranches = sum(self.nbranches_per_cell)\n    self.cumsum_nbranches = jnp.cumsum(jnp.asarray([0] + self.nbranches_per_cell))\n    self.cumsum_nbranchpoints = jnp.cumsum(\n        jnp.asarray([0] + self.nbranchpoints_per_cell)\n    )\n\n    self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n    self.nodes[\"branch_index\"] = (\n        np.arange(self.nseg * self.total_nbranches) // self.nseg\n    ).tolist()\n    self.nodes[\"cell_index\"] = list(\n        itertools.chain(\n            *[[i] * (self.nseg * b) for i, b in enumerate(self.nbranches_per_cell)]\n        )\n    )\n\n    parents = [cell.comb_parents for cell in self.cells]\n    self.comb_parents = jnp.concatenate(\n        [p.at[1:].add(self.cumsum_nbranches[i]) for i, p in enumerate(parents)]\n    )\n\n    # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n    # branch, apart from those branches which do not have a parent (i.e.\n    # -1 in parents). For every branch, tracks the global index of that branch\n    # (`child_branch_index`) and the global index of its parent\n    # (`parent_branch_index`).\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[self.comb_parents != -1],\n            child_branch_index=np.where(self.comb_parents != -1)[0],\n        )\n    )\n\n    # For morphology indexing.\n    par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n    self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n    self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n    self.par_inds = np.unique(par_inds)  # TODO: does order have to be preserved?\n    self.total_nbranchpoints = len(self.par_inds)\n    self.root_inds = self.cumsum_nbranches[:-1]\n\n    # Channels.\n    self._gather_channels_from_constituents(cells)\n\n    self.initialize()\n    self.init_syns()\n    self.initialized_conds = False\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.init_conds","title":"init_conds(params)","text":"

    Given an axial resisitivity, set the coupling conductances.

    Source code in jaxley/modules/network.py
    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n    nbranches = self.total_nbranches\n    nseg = self.nseg\n    parents = self.comb_parents\n\n    axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n    radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n    lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n    conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n        axial_resistivity, radiuses, lengths, self.nseg\n    )\n    coupling_conds_fwd = conds[0]\n    coupling_conds_bwd = conds[1]\n    summed_coupling_conds = conds[2]\n\n    # The conductance from the children to the branch point.\n    branchpoint_conds_children = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The conductance from the parents to the branch point.\n    branchpoint_conds_parents = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    # Weights with which the compartments influence their nearby node.\n    # The impact of the children on the branch point.\n    branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The impact of parents on the branch point.\n    branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    summed_coupling_conds = Cell.update_summed_coupling_conds(\n        summed_coupling_conds,\n        self.child_inds,\n        self.par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    )\n\n    cond_params = {\n        \"branch_uppers\": coupling_conds_bwd,\n        \"branch_lowers\": coupling_conds_fwd,\n        \"branch_diags\": summed_coupling_conds,\n        \"branchpoint_conds_children\": branchpoint_conds_children,\n        \"branchpoint_conds_parents\": branchpoint_conds_parents,\n        \"branchpoint_weights_children\": branchpoint_weights_children,\n        \"branchpoint_weights_parents\": branchpoint_weights_parents,\n    }\n    return cond_params\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.init_syns","title":"init_syns()","text":"

    Initialize synapses.

    Source code in jaxley/modules/network.py
    def init_syns(self):\n    \"\"\"Initialize synapses.\"\"\"\n    self.synapses = []\n\n    # TODO(@michaeldeistler): should we also track this for channels?\n    self.synapse_names = []\n    self.synapse_param_names = []\n    self.synapse_state_names = []\n\n    self.initialized_syns = True\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, col='k', synapse_col='b', dims=(0, 1), type='line', layers=None, morph_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, networkx_options={}, layer_kwargs={})","text":"

    Visualize the module.

    Parameters:

    Name Type Description Default detail str

    Either of [point, full]. point visualizes every neuron in the network as a dot (and it uses networkx to obtain cell positions). full plots the full morphology of every neuron. It requires that compute_xyz() has been run and allows for indivual neurons to be moved with .move().

    'full' col str

    The color in which cells are plotted. Only takes effect if detail='full'.

    'k' type str

    Either line or scatter. Only takes effect if detail='full'.

    'line' synapse_col str

    The color in which synapses are plotted. Only takes effect if detail='full'.

    'b' dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) layers Optional[List]

    Allows to plot the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron.

    None morph_plot_kwargs Dict

    Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

    {} synapse_plot_kwargs Dict

    Keyword arguments passed to the plotting function for syanpses. Only takes effect for detail='full'.

    {} synapse_scatter_kwargs Dict

    Keyword arguments passed to the scatter function for the end point of synapses. Only takes effect for detail='full'.

    {} networkx_options Dict

    Options passed to networkx.draw(). Only takes effect if detail='point'.

    {} layer_kwargs Dict

    Only used if layers is specified and if detail='full'. Can have the following entries: within_layer_offset (float), between_layer_offset (float), vertical_layers (bool).

    {} Source code in jaxley/modules/network.py
    def vis(\n    self,\n    detail: str = \"full\",\n    ax: Optional[Axes] = None,\n    col: str = \"k\",\n    synapse_col: str = \"b\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    layers: Optional[List] = None,\n    morph_plot_kwargs: Dict = {},\n    synapse_plot_kwargs: Dict = {},\n    synapse_scatter_kwargs: Dict = {},\n    networkx_options: Dict = {},\n    layer_kwargs: Dict = {},\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        detail: Either of [point, full]. `point` visualizes every neuron in the\n            network as a dot (and it uses `networkx` to obtain cell positions).\n            `full` plots the full morphology of every neuron. It requires that\n            `compute_xyz()` has been run and allows for indivual neurons to be\n            moved with `.move()`.\n        col: The color in which cells are plotted. Only takes effect if\n            `detail='full'`.\n        type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n        synapse_col: The color in which synapses are plotted. Only takes effect if\n            `detail='full'`.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        layers: Allows to plot the network in layers. Should provide the number of\n            neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n            neurons, 10 hidden layer neurons, and 1 output neuron.\n        morph_plot_kwargs: Keyword arguments passed to the plotting function for\n            cell morphologies. Only takes effect for `detail='full'`.\n        synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n            syanpses. Only takes effect for `detail='full'`.\n        synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n            for the end point of synapses. Only takes effect for `detail='full'`.\n        networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n            `detail='point'`.\n        layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n            Can have the following entries: `within_layer_offset` (float),\n            `between_layer_offset` (float), `vertical_layers` (bool).\n    \"\"\"\n    if detail == \"point\":\n        graph = self._build_graph(layers)\n\n        if layers is not None:\n            pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n            nx.draw(graph, pos, with_labels=True, **networkx_options)\n        else:\n            nx.draw(graph, with_labels=True, **networkx_options)\n    elif detail == \"full\":\n        if layers is not None:\n            # Assemble cells in the network into layers.\n            global_counter = 0\n            layers_config = {\n                \"within_layer_offset\": 500.0,\n                \"between_layer_offset\": 1500.0,\n                \"vertical_layers\": False,\n            }\n            layers_config.update(layer_kwargs)\n            for layer_ind, num_in_layer in enumerate(layers):\n                for ind_within_layer in range(num_in_layer):\n                    if layers_config[\"vertical_layers\"]:\n                        x_offset = (\n                            ind_within_layer - (num_in_layer - 1) / 2\n                        ) * layers_config[\"within_layer_offset\"]\n                        y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n                            \"between_layer_offset\"\n                        ]\n                    else:\n                        x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n                        y_offset = (\n                            ind_within_layer - (num_in_layer - 1) / 2\n                        ) * layers_config[\"within_layer_offset\"]\n\n                    self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n                    global_counter += 1\n        ax = self._vis(\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=type,\n            view=self.nodes,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        pre_locs = self.edges[\"pre_locs\"].to_numpy()\n        post_locs = self.edges[\"post_locs\"].to_numpy()\n        pre_branch = self.edges[\"global_pre_branch_index\"].to_numpy()\n        post_branch = self.edges[\"global_post_branch_index\"].to_numpy()\n\n        dims_np = np.asarray(dims)\n\n        for pre_loc, post_loc, pre_b, post_b in zip(\n            pre_locs, post_locs, pre_branch, post_branch\n        ):\n            pre_coord = self.xyzr[pre_b]\n            if len(pre_coord) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(pre_coord) - 1) * pre_loc)\n                pre_coord = pre_coord[middle_ind]\n\n            post_coord = self.xyzr[post_b]\n            if len(post_coord) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                post_coord = (\n                    post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n                )\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(post_coord) - 1) * post_loc)\n                post_coord = post_coord[middle_ind]\n\n            coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n            ax.plot(\n                coords[0],\n                coords[1],\n                c=synapse_col,\n                **synapse_plot_kwargs,\n            )\n            ax.scatter(\n                post_coord[dims_np[0]],\n                post_coord[dims_np[1]],\n                c=synapse_col,\n                **synapse_scatter_kwargs,\n            )\n    else:\n        raise ValueError(\"detail must be in {full, point}.\")\n\n    return ax\n
    "},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer","text":"

    optax wrapper which allows different argument values for different params.

    Source code in jaxley/optimize/optimizer.py
    class TypeOptimizer:\n    \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n    def __init__(\n        self,\n        optimizer: Callable,\n        optimizer_args: Dict[str, Any],\n        opt_params: List[Dict[str, jnp.ndarray]],\n    ):\n        \"\"\"Create the optimizers.\n\n        This requires access to `opt_params` in order to know how many optimizers\n        should be created. It creates `len(opt_params)` optimizers.\n\n        Example usage:\n        ```\n        lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        ```\n        optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n        optimizer = TypeOptimizer(\n            lambda args: optax.sgd(args[0], momentum=args[1]),\n            optimizer_args,\n            opt_params\n        )\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        Args:\n            optimizer: A Callable that takes the learning rate and returns the\n                `optax.optimizer` which should be used.\n            optimizer_args: The arguments for different kinds of parameters.\n                Each item of the dictionary will be passed to the `Callable` passed to\n                `optimizer`.\n            opt_params: The parameters to be optimized. The exact values are not used,\n                only the number of elements in the list and the key of each dict.\n        \"\"\"\n        self.base_optimizer = optimizer\n\n        self.optimizers = []\n        for params in opt_params:\n            names = list(params.keys())\n            assert len(names) == 1, \"Multiple parameters were added at once.\"\n            name = names[0]\n            optimizer = self.base_optimizer(optimizer_args[name])\n            self.optimizers.append({name: optimizer})\n\n    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n        \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n        opt_states = []\n        for params, optimizer in zip(opt_params, self.optimizers):\n            name = list(optimizer.keys())[0]\n            opt_state = optimizer[name].init(params)\n            opt_states.append(opt_state)\n        return opt_states\n\n    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n        \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n        all_updates = []\n        new_opt_states = []\n        for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n            name = list(opt.keys())[0]\n            updates, new_opt_state = opt[name].update(grad, state)\n            all_updates.append(updates)\n            new_opt_states.append(new_opt_state)\n        return all_updates, new_opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)","text":"

    Create the optimizers.

    This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

    Example usage:

    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n

    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n    lambda args: optax.sgd(args[0], momentum=args[1]),\n    optimizer_args,\n    opt_params\n)\nopt_state = optimizer.init(opt_params)\n

    Parameters:

    Name Type Description Default optimizer Callable

    A Callable that takes the learning rate and returns the optax.optimizer which should be used.

    required optimizer_args Dict[str, Any]

    The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable passed to optimizer.

    required opt_params List[Dict[str, ndarray]]

    The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.

    required Source code in jaxley/optimize/optimizer.py
    def __init__(\n    self,\n    optimizer: Callable,\n    optimizer_args: Dict[str, Any],\n    opt_params: List[Dict[str, jnp.ndarray]],\n):\n    \"\"\"Create the optimizers.\n\n    This requires access to `opt_params` in order to know how many optimizers\n    should be created. It creates `len(opt_params)` optimizers.\n\n    Example usage:\n    ```\n    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    ```\n    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n    optimizer = TypeOptimizer(\n        lambda args: optax.sgd(args[0], momentum=args[1]),\n        optimizer_args,\n        opt_params\n    )\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    Args:\n        optimizer: A Callable that takes the learning rate and returns the\n            `optax.optimizer` which should be used.\n        optimizer_args: The arguments for different kinds of parameters.\n            Each item of the dictionary will be passed to the `Callable` passed to\n            `optimizer`.\n        opt_params: The parameters to be optimized. The exact values are not used,\n            only the number of elements in the list and the key of each dict.\n    \"\"\"\n    self.base_optimizer = optimizer\n\n    self.optimizers = []\n    for params in opt_params:\n        names = list(params.keys())\n        assert len(names) == 1, \"Multiple parameters were added at once.\"\n        name = names[0]\n        optimizer = self.base_optimizer(optimizer_args[name])\n        self.optimizers.append({name: optimizer})\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)","text":"

    Initialize the optimizers. Equivalent to optax.optimizers.init().

    Source code in jaxley/optimize/optimizer.py
    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n    \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n    opt_states = []\n    for params, optimizer in zip(opt_params, self.optimizers):\n        name = list(optimizer.keys())[0]\n        opt_state = optimizer[name].init(params)\n        opt_states.append(opt_state)\n    return opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)","text":"

    Update the optimizers. Equivalent to optax.optimizers.update().

    Source code in jaxley/optimize/optimizer.py
    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n    \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n    all_updates = []\n    new_opt_states = []\n    for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n        name = list(opt.keys())[0]\n        updates, new_opt_state = opt[name].update(grad, state)\n        all_updates.append(updates)\n        new_opt_states.append(new_opt_state)\n    return all_updates, new_opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform","text":"

    Parameter transformation utility.

    This class is used to transform parameters from an unconstrained space to a constrained space and back. If the range is bounded both from above and below, we use the sigmoid function to transform the parameters. If the range is only bounded from below or above, we use softplus.

    Attributes:

    Name Type Description lowers

    A dictionary of lower bounds for each parameter (None for no bound).

    uppers

    A dictionary of upper bounds for each parameter (None for no bound).

    Source code in jaxley/optimize/transforms.py
    class ParamTransform:\n    \"\"\"Parameter transformation utility.\n\n    This class is used to transform parameters from an unconstrained space to a constrained space\n    and back. If the range is bounded both from above and below, we use the sigmoid function to\n    transform the parameters. If the range is only bounded from below or above, we use softplus.\n\n    Attributes:\n        lowers: A dictionary of lower bounds for each parameter (None for no bound).\n        uppers: A dictionary of upper bounds for each parameter (None for no bound).\n\n    \"\"\"\n\n    def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):\n        \"\"\"Initialize the bounds.\n\n        Args:\n            lowers: A dictionary of lower bounds for each parameter (None for no bound).\n            uppers: A dictionary of upper bounds for each parameter (None for no bound).\n        \"\"\"\n\n        self.lowers = lowers\n        self.uppers = uppers\n\n    def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:\n        \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n        Args:\n            params: A list of dictionaries with unconstrained parameters.\n\n        Returns:\n            A list of dictionaries with transformed parameters.\n\n        \"\"\"\n\n        tf_params = []\n        for param in params:\n            key = list(param.keys())[0]\n\n            # If constrained from below and above, use sigmoid\n            if self.lowers[key] is not None and self.uppers[key] is not None:\n                tf = (\n                    sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])\n                    + self.lowers[key]\n                )\n                tf_params.append({key: tf})\n\n            # If constrained from below, use softplus\n            elif self.lowers[key] is not None:\n                tf = softplus(param[key]) + self.lowers[key]\n                tf_params.append({key: tf})\n\n            # If constrained from above, use negative softplus\n            elif self.uppers[key] is not None:\n                tf = -softplus(-param[key]) + self.uppers[key]\n                tf_params.append({key: tf})\n\n            # Else just pass through\n            else:\n                tf_params.append({key: param[key]})\n\n        return tf_params\n\n    def inverse(self, params: jnp.ndarray) -> jnp.ndarray:\n        \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n        Args:\n            params: A list of dictionaries with transformed parameters.\n\n        Returns:\n            A list of dictionaries with unconstrained parameters.\n        \"\"\"\n\n        tf_params = []\n        for param in params:\n            key = list(param.keys())[0]\n\n            # If constrained from below and above, use expit\n            if self.lowers[key] is not None and self.uppers[key] is not None:\n                tf = expit(\n                    (param[key] - self.lowers[key])\n                    / (self.uppers[key] - self.lowers[key])\n                )\n                tf_params.append({key: tf})\n\n            # If constrained from below, use inv_softplus\n            elif self.lowers[key] is not None:\n                tf = inv_softplus(param[key] - self.lowers[key])\n                tf_params.append({key: tf})\n\n            # If constrained from above, use negative inv_softplus\n            elif self.uppers[key] is not None:\n                tf = -inv_softplus(-(param[key] - self.uppers[key]))\n                tf_params.append({key: tf})\n\n            # else just pass through\n            else:\n                tf_params.append({key: param[key]})\n\n        return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(lowers, uppers)","text":"

    Initialize the bounds.

    Parameters:

    Name Type Description Default lowers Dict[str, float]

    A dictionary of lower bounds for each parameter (None for no bound).

    required uppers Dict[str, float]

    A dictionary of upper bounds for each parameter (None for no bound).

    required Source code in jaxley/optimize/transforms.py
    def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):\n    \"\"\"Initialize the bounds.\n\n    Args:\n        lowers: A dictionary of lower bounds for each parameter (None for no bound).\n        uppers: A dictionary of upper bounds for each parameter (None for no bound).\n    \"\"\"\n\n    self.lowers = lowers\n    self.uppers = uppers\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)","text":"

    Pushes unconstrained parameters through a tf such that they fit the interval.

    Parameters:

    Name Type Description Default params List[Dict[str, ndarray]]

    A list of dictionaries with unconstrained parameters.

    required

    Returns:

    Type Description ndarray

    A list of dictionaries with transformed parameters.

    Source code in jaxley/optimize/transforms.py
    def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:\n    \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n    Args:\n        params: A list of dictionaries with unconstrained parameters.\n\n    Returns:\n        A list of dictionaries with transformed parameters.\n\n    \"\"\"\n\n    tf_params = []\n    for param in params:\n        key = list(param.keys())[0]\n\n        # If constrained from below and above, use sigmoid\n        if self.lowers[key] is not None and self.uppers[key] is not None:\n            tf = (\n                sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])\n                + self.lowers[key]\n            )\n            tf_params.append({key: tf})\n\n        # If constrained from below, use softplus\n        elif self.lowers[key] is not None:\n            tf = softplus(param[key]) + self.lowers[key]\n            tf_params.append({key: tf})\n\n        # If constrained from above, use negative softplus\n        elif self.uppers[key] is not None:\n            tf = -softplus(-param[key]) + self.uppers[key]\n            tf_params.append({key: tf})\n\n        # Else just pass through\n        else:\n            tf_params.append({key: param[key]})\n\n    return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)","text":"

    Takes parameters from within the interval and makes them unconstrained.

    Parameters:

    Name Type Description Default params ndarray

    A list of dictionaries with transformed parameters.

    required

    Returns:

    Type Description ndarray

    A list of dictionaries with unconstrained parameters.

    Source code in jaxley/optimize/transforms.py
    def inverse(self, params: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n    Args:\n        params: A list of dictionaries with transformed parameters.\n\n    Returns:\n        A list of dictionaries with unconstrained parameters.\n    \"\"\"\n\n    tf_params = []\n    for param in params:\n        key = list(param.keys())[0]\n\n        # If constrained from below and above, use expit\n        if self.lowers[key] is not None and self.uppers[key] is not None:\n            tf = expit(\n                (param[key] - self.lowers[key])\n                / (self.uppers[key] - self.lowers[key])\n            )\n            tf_params.append({key: tf})\n\n        # If constrained from below, use inv_softplus\n        elif self.lowers[key] is not None:\n            tf = inv_softplus(param[key] - self.lowers[key])\n            tf_params.append({key: tf})\n\n        # If constrained from above, use negative inv_softplus\n        elif self.uppers[key] is not None:\n            tf = -inv_softplus(-(param[key] - self.uppers[key]))\n            tf_params.append({key: tf})\n\n        # else just pass through\n        else:\n            tf_params.append({key: param[key]})\n\n    return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.expit","title":"expit(x)","text":"

    Inverse sigmoid (expit)

    Source code in jaxley/optimize/transforms.py
    def expit(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Inverse sigmoid (expit)\"\"\"\n    return -jnp.log(1 / x - 1)\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.inv_softplus","title":"inv_softplus(x)","text":"

    Inverse softplus.

    Source code in jaxley/optimize/transforms.py
    def inv_softplus(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Inverse softplus.\"\"\"\n    return jnp.log(jnp.exp(x) - 1)\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.sigmoid","title":"sigmoid(x)","text":"

    Sigmoid.

    Source code in jaxley/optimize/transforms.py
    def sigmoid(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Sigmoid.\"\"\"\n    return 1 / (1 + save_exp(-x))\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.softplus","title":"softplus(x)","text":"

    Softplus.

    Source code in jaxley/optimize/transforms.py
    def softplus(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Softplus.\"\"\"\n    return jnp.log(1 + jnp.exp(x))\n
    "},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)","text":"

    Return all children indices of every branch.

    Example:

    parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n

    Source code in jaxley/utils/cell_utils.py
    def compute_children_indices(parents) -> List[jnp.ndarray]:\n    \"\"\"Return all children indices of every branch.\n\n    Example:\n    ```\n    parents = [-1, 0, 0]\n    compute_children_indices(parents) -> [[1, 2], [], []]\n    ```\n    \"\"\"\n    num_branches = len(parents)\n    child_indices = []\n    for b in range(num_branches):\n        child_indices.append(np.where(parents == b)[0])\n    return child_indices\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)","text":"

    Return the coupling conductance between two compartments.

    Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

    radius: um r_a: ohm cm length_single_compartment: um coupling_conds: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2

    Source code in jaxley/utils/cell_utils.py
    def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n    \"\"\"Return the coupling conductance between two compartments.\n\n    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n    `radius`: um\n    `r_a`: ohm cm\n    `length_single_compartment`: um\n    `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n    \"\"\"\n    # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n    return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)","text":"

    Return the coupling conductance between one compartment and a comp with l=0.

    From https://en.wikipedia.org/wiki/Compartmental_neuron_models

    If one compartment has l=0.0 then the equations simplify.

    R_long = \\sum_i r_a * L_i/2 / crosssection_i

    with crosssection = pi * r**2

    For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection

    Then, g_long = crosssection * 2 / L / r_a

    Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2

    Source code in jaxley/utils/cell_utils.py
    def compute_coupling_cond_branchpoint(rad, r_a, l):\n    r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n    From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n    If one compartment has l=0.0 then the equations simplify.\n\n    R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n    with crosssection = pi * r**2\n\n    For a single compartment with L>0, this turns into:\n    R_long = r_a * L/2 / crosssection\n\n    Then, g_long = crosssection * 2 / L / r_a\n\n    Then, the effective conductance is g_long / zylinder_area. So:\n    g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n    g = r / r_a / L**2\n    \"\"\"\n    return rad / r_a / l**2 * 10**7  # Convert (S / cm / um) -> (mS / cm^2)\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)","text":"

    Compute the weight with which a compartment influences its node.

    In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0

    Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a

    This equation can be multiplied by any constant.

    Source code in jaxley/utils/cell_utils.py
    def compute_impact_on_node(rad, r_a, l):\n    r\"\"\"Compute the weight with which a compartment influences its node.\n\n    In order to satisfy Kirchhoffs current law, the current at a branch point must be\n    proportional to the crosssection of the compartment. We only require proportionality\n    here because the branch point equation reads:\n    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n    Because R_long = r_a * L/2 / crosssection, we get\n    g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n    This equation can be multiplied by any constant.\"\"\"\n    return rad**2 / r_a / l\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)","text":"

    Return (row, col) to build the sparse matrix defining the voltage eqs.

    This is run at init, not during runtime.

    Source code in jaxley/utils/cell_utils.py
    def compute_morphology_indices_in_levels(\n    num_branchpoints,\n    child_belongs_to_branchpoint,\n    par_inds,\n    child_inds,\n):\n    \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n    This is run at `init`, not during runtime.\n    \"\"\"\n    branchpoint_inds_parents = jnp.arange(num_branchpoints)\n    branchpoint_inds_children = child_belongs_to_branchpoint\n    branch_inds_parents = par_inds\n    branch_inds_children = child_inds\n\n    children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n    parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n    return {\"children\": children.T, \"parents\": parents.T}\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)","text":"

    Convert current point process (nA) to distributed current (uA/cm2).

    This function gets called for synapses and for external stimuli.

    Parameters:

    Name Type Description Default current ndarray

    Current in nA.

    required radius ndarray

    Compartment radius in um.

    required length ndarray

    Compartment length in um.

    required Return

    Current in uA/cm2.

    Source code in jaxley/utils/cell_utils.py
    def convert_point_process_to_distributed(\n    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n    \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n    This function gets called for synapses and for external stimuli.\n\n    Args:\n        current: Current in `nA`.\n        radius: Compartment radius in `um`.\n        length: Compartment length in `um`.\n\n    Return:\n        Current in `uA/cm2`.\n    \"\"\"\n    area = 2 * pi * radius * length\n    current /= area  # nA / um^2\n    return current * 100_000  # Convert (nA / um^2) to (uA / cm^2)\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, nseg_per_branch)","text":"

    Generates segments where some property is the same in each segment.

    Parameters:

    Name Type Description Default branch_property list

    List of values of the property in each branch. Should have len(branch_property) == num_branches.

    required Source code in jaxley/utils/cell_utils.py
    def equal_segments(branch_property: list, nseg_per_branch: int):\n    \"\"\"Generates segments where some property is the same in each segment.\n\n    Args:\n        branch_property: List of values of the property in each branch. Should have\n            `len(branch_property) == num_branches`.\n    \"\"\"\n    assert isinstance(branch_property, list), \"branch_property must be a list.\"\n    return jnp.asarray([branch_property] * nseg_per_branch).T\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, nseg_per_branch, num_branches)","text":"

    Number of neighbours of each compartment.

    Source code in jaxley/utils/cell_utils.py
    def get_num_neighbours(\n    num_children: jnp.ndarray,\n    nseg_per_branch: int,\n    num_branches: int,\n):\n    \"\"\"\n    Number of neighbours of each compartment.\n    \"\"\"\n    num_neighbours = 2 * jnp.ones((num_branches * nseg_per_branch))\n    num_neighbours = num_neighbours.at[nseg_per_branch - 1].set(1.0)\n    num_neighbours = num_neighbours.at[jnp.arange(num_branches) * nseg_per_branch].set(\n        num_children + 1.0\n    )\n    return num_neighbours\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)","text":"

    Group values by whether they have the same integer and sum values within group.

    This is used to construct the last diagonals at the branch points.

    Written by ChatGPT.

    Source code in jaxley/utils/cell_utils.py
    def group_and_sum(\n    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n    \"\"\"Group values by whether they have the same integer and sum values within group.\n\n    This is used to construct the last diagonals at the branch points.\n\n    Written by ChatGPT.\n    \"\"\"\n    # Initialize an array to hold the sum of each group\n    group_sums = jnp.zeros(num_branchpoints)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n    return group_sums\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.index_of_loc","title":"index_of_loc(branch_ind, loc, nseg_per_branch)","text":"

    Returns the index of a segment given a loc in [0, 1] and the index of a branch.

    This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

    Parameters:

    Name Type Description Default branch_ind int

    Index of the branch.

    required loc float

    Location (in [0, 1]) along that branch.

    required nseg_per_branch int

    Number of segments of each branch.

    required

    Returns:

    Type Description int

    The index of the compartment within the entire cell.

    Source code in jaxley/utils/cell_utils.py
    def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:\n    \"\"\"Returns the index of a segment given a loc in [0, 1] and the index of a branch.\n\n    This is used because we specify locations such as synapses as a value between 0 and\n    1. We have to convert this onto a discrete segment here.\n\n    Args:\n        branch_ind: Index of the branch.\n        loc: Location (in [0, 1]) along that branch.\n        nseg_per_branch: Number of segments of each branch.\n\n    Returns:\n        The index of the compartment within the entire cell.\n    \"\"\"\n    nseg = nseg_per_branch  # only for convenience.\n    possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)\n    ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n    return branch_ind * nseg + ind_along_branch\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyz","title":"interpolate_xyz(loc, coords)","text":"

    Perform a linear interpolation between xyz-coordinates.

    Parameters:

    Name Type Description Default loc float

    The location in [0,1] along the branch.

    required coords ndarray

    Array containing the reconstructed xyzr points of the branch.

    required Return

    Interpolated xyz coordinate at loc, shape `(3,).

    Source code in jaxley/utils/cell_utils.py
    def interpolate_xyz(loc: float, coords: np.ndarray):\n    \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n    Args:\n        loc: The location in [0,1] along the branch.\n        coords: Array containing the reconstructed xyzr points of the branch.\n\n    Return:\n        Interpolated xyz coordinate at `loc`, shape `(3,).\n    \"\"\"\n    return vmap(lambda x: jnp.interp(loc, jnp.linspace(0, 1, len(x)), x), in_axes=(1,))(\n        coords[:, :3]\n    )\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, nseg_per_branch)","text":"

    Generates segments where some property is linearly interpolated.

    Parameters:

    Name Type Description Default initial_val float

    The value at the tip of the soma.

    required endpoint_vals list

    The value at the endpoints of each branch.

    required Source code in jaxley/utils/cell_utils.py
    def linear_segments(\n    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, nseg_per_branch: int\n):\n    \"\"\"Generates segments where some property is linearly interpolated.\n\n    Args:\n        initial_val: The value at the tip of the soma.\n        endpoint_vals: The value at the endpoints of each branch.\n    \"\"\"\n    branch_property = endpoint_vals + [initial_val]\n    num_branches = len(parents)\n    # Compute radiuses by linear interpolation.\n    endpoint_radiuses = jnp.asarray(branch_property)\n\n    def compute_rad(branch_ind, loc):\n        start = endpoint_radiuses[parents[branch_ind]]\n        end = endpoint_radiuses[branch_ind]\n        return (end - start) * loc + start\n\n    branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), nseg_per_branch)\n    locs_of_each_comp = jnp.linspace(1, 0, nseg_per_branch).repeat(num_branches)\n    rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n    return jnp.reshape(rad_of_each_comp, (nseg_per_branch, num_branches)).T\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, nseg)","text":"

    Return location corresponding to index.

    Source code in jaxley/utils/cell_utils.py
    def loc_of_index(global_comp_index, nseg):\n    \"\"\"Return location corresponding to index.\"\"\"\n    index = global_comp_index % nseg\n    possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)\n    return possible_locs[index]\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)","text":"

    Build full list of which branches are solved in which iteration.

    From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.

    Parameters:

    Name Type Description Default cumsum_num_branches List[int]

    cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30].

    required arrs List[List[ndarray]]

    A list of a list of arrays that should be merged.

    required exclude_first bool

    If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates \u201cno parent\u201d) entry should not be changed.

    True

    Returns:

    Type Description ndarray

    A list of arrays which contain the branch indices that are computed at each

    ndarray

    level (i.e., iteration).

    Source code in jaxley/utils/cell_utils.py
    def merge_cells(\n    cumsum_num_branches: List[int],\n    cumsum_num_branchpoints: List[int],\n    arrs: List[List[jnp.ndarray]],\n    exclude_first: bool = True,\n) -> jnp.ndarray:\n    \"\"\"\n    Build full list of which branches are solved in which iteration.\n\n    From the branching pattern of single cells, this \"merges\" them into a single\n    ordering of branches.\n\n    Args:\n        cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n            10, 15, and 5 branches respectively, this will should be a list containing\n            `[0, 10, 25, 30]`.\n        arrs: A list of a list of arrays that should be merged.\n        exclude_first: If `True`, the first element of each list in `arrs` will remain\n            unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n            be changed.\n\n    Returns:\n        A list of arrays which contain the branch indices that are computed at each\n        level (i.e., iteration).\n    \"\"\"\n    ps = []\n    for i, att in enumerate(arrs):\n        p = att\n        if exclude_first:\n            raise NotImplementedError\n            p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n        else:\n            p = [\n                p_in_level\n                + jnp.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n                for p_in_level in p\n            ]\n        ps.append(p)\n\n    max_len = max([len(att) for att in arrs])\n    combined_parents_in_level = []\n    for i in range(max_len):\n        current_ps = []\n        for p in ps:\n            if len(p) > i:\n                current_ps.append(p[i])\n        combined_parents_in_level.append(jnp.concatenate(current_ps))\n\n    return combined_parents_in_level\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)","text":"

    Make outputs get_parameters() conform with outputs of .data_set().

    make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params). Therefore, in jx.integrate, we run the function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

    Source code in jaxley/utils/cell_utils.py
    def params_to_pstate(\n    params: List[Dict[str, jnp.ndarray]],\n    indices_set_by_trainables: List[jnp.ndarray],\n):\n    \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n    `make_trainable()` followed by `params=get_parameters()` does not return indices\n    because these indices would also be differentiated by `jax.grad` (as soon as\n    the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n    we run the function to add indices to the dict. The outputs of `params_to_pstate`\n    are of the same shape as the outputs of `.data_set()`.\"\"\"\n    return [\n        {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n        for p, i in zip(params, indices_set_by_trainables)\n    ]\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)","text":"

    Maps an array of integers to an array of consecutive integers.

    E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

    Source code in jaxley/utils/cell_utils.py
    def remap_to_consecutive(arr):\n    \"\"\"Maps an array of integers to an array of consecutive integers.\n\n    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n    \"\"\"\n    _, inverse_indices = jnp.unique(arr, return_inverse=True)\n    return inverse_indices\n
    "},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})","text":"

    Plot morphology.

    Parameters:

    Name Type Description Default dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) type str

    Either line or scatter.

    'line' col str

    The color for all branches.

    'k' Source code in jaxley/utils/plot_utils.py
    def plot_morph(\n    xyzr,\n    dims: Tuple[int] = (0, 1),\n    col: str = \"k\",\n    ax: Optional[Axes] = None,\n    type: str = \"line\",\n    morph_plot_kwargs: Dict = {},\n):\n    \"\"\"Plot morphology.\n\n    Args:\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        type: Either `line` or `scatter`.\n        col: The color for all branches.\n    \"\"\"\n\n    if ax is None:\n        _, ax = plt.subplots(1, 1, figsize=(3, 3))\n\n    for coords_of_branch in xyzr:\n        x1, x2 = coords_of_branch[:, dims].T\n\n        if \"line\" in type.lower():\n            _ = ax.plot(x1, x2, color=col, **morph_plot_kwargs)\n        elif \"scatter\" in type.lower():\n            _ = ax.scatter(x1, x2, color=col, **morph_plot_kwargs)\n        else:\n            raise NotImplementedError\n\n    return ax\n
    "},{"location":"reference/utils/#jaxley.utils.swc.swc_to_jaxley","title":"swc_to_jaxley(fname, max_branch_len=100.0, sort=True, num_lines=None)","text":"

    Read an SWC file and bring morphology into jaxley compatible formats.

    Parameters:

    Name Type Description Default fname str

    Path to swc file.

    required max_branch_len float

    Maximal length of one branch. If a branch exceeds this length, it is split into equal parts such that each subbranch is below max_branch_len.

    100.0 num_lines Optional[int]

    Number of lines of the SWC file to read.

    None Source code in jaxley/utils/swc.py
    def swc_to_jaxley(\n    fname: str,\n    max_branch_len: float = 100.0,\n    sort: bool = True,\n    num_lines: Optional[int] = None,\n) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]:\n    \"\"\"Read an SWC file and bring morphology into `jaxley` compatible formats.\n\n    Args:\n        fname: Path to swc file.\n        max_branch_len: Maximal length of one branch. If a branch exceeds this length,\n            it is split into equal parts such that each subbranch is below\n            `max_branch_len`.\n        num_lines: Number of lines of the SWC file to read.\n    \"\"\"\n    content = np.loadtxt(fname)[:num_lines]\n    types = content[:, 1]\n    is_single_point_soma = types[0] == 1 and types[1] != 1\n\n    if is_single_point_soma:\n        # Warn here, but the conversion of the length happens in `_compute_pathlengths`.\n        warn(\n            \"Found a soma which consists of a single traced point. `Jaxley` \"\n            \"interprets this soma as a spherical compartment with radius \"\n            \"specified in the SWC file, i.e. with surface area 4*pi*r*r.\"\n        )\n    sorted_branches, types = _split_into_branches_and_sort(\n        content,\n        max_branch_len=max_branch_len,\n        is_single_point_soma=is_single_point_soma,\n        sort=sort,\n    )\n\n    parents = _build_parents(sorted_branches)\n    each_length = _compute_pathlengths(\n        sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma\n    )\n    pathlengths = [np.sum(length_traced) for length_traced in each_length]\n    for i, pathlen in enumerate(pathlengths):\n        if pathlen == 0.0:\n            warn(\"Found a segment with length 0. Clipping it to 1.0\")\n            pathlengths[i] = 1.0\n    radius_fns = _radius_generating_fns(\n        sorted_branches, content[:, 5], each_length, parents, types\n    )\n\n    if np.sum(np.asarray(parents) == -1) > 1.0:\n        parents = np.asarray([-1] + parents)\n        parents[1:] += 1\n        parents = parents.tolist()\n        pathlengths = [0.1] + pathlengths\n        radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns\n        sorted_branches = [[0]] + sorted_branches\n\n        # Type of padded section is assumed to be of `custom` type:\n        # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html\n        types = [5.0] + types\n\n    all_coords_of_branches = []\n    for i, branch in enumerate(sorted_branches):\n        # Remove 1 because `content` is an array that is indexed from 0.\n        branch = np.asarray(branch) - 1\n\n        # Deal with additional branch that might have been added above in the lines\n        # `if np.sum(np.asarray(parents) == -1) > 1.0:`\n        branch[branch < 0] = 0\n\n        # Get traced coordinates of the branch.\n        coords_of_branch = content[branch, 2:6]\n        all_coords_of_branches.append(coords_of_branch)\n\n    return parents, pathlengths, radius_fns, types, all_coords_of_branches\n
    "},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)","text":"

    A version of lax.scan that supports recursive gradient checkpointing.

    Code taken from: https://github.com/google/jax/issues/2139

    The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

    The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

    nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

    Parameters:

    Name Type Description Default f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

    function to scan over.

    required init Carry

    initial value.

    required xs Dict[str, ndarray]

    scanned over values.

    required length Optional[int]

    leading length of all dimensions

    None nested_lengths Sequence[int]

    required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

    required scan_fn

    function matching the API of lax.scan

    scan checkpoint_fn Callable[[Func], Func]

    function matching the API of jax.checkpoint.

    checkpoint Source code in jaxley/utils/jax_utils.py
    def nested_checkpoint_scan(\n    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n    init: Carry,\n    xs: Dict[str, jnp.ndarray],\n    length: Optional[int] = None,\n    *,\n    nested_lengths: Sequence[int],\n    scan_fn=jax.lax.scan,\n    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n    \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n    Code taken from: https://github.com/google/jax/issues/2139\n\n    The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n    the required `nested_lengths` argument.\n\n    The key feature of `nested_checkpoint_scan` is that gradient calculations\n    require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n    scans, which it achieves by re-evaluating the forward pass\n    `len(nested_lengths) - 1` times.\n\n    `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n    single element.\n\n    Args:\n        f: function to scan over.\n        init: initial value.\n        xs: scanned over values.\n        length: leading length of all dimensions\n        nested_lengths: required list of lengths to scan over for each level of\n            checkpointing. The product of nested_lengths must match length (if\n            provided) and the size of the leading axis for all arrays in ``xs``.\n        scan_fn: function matching the API of lax.scan\n        checkpoint_fn: function matching the API of jax.checkpoint.\n    \"\"\"\n    if length is not None and length != math.prod(nested_lengths):\n        raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n    def nested_reshape(x):\n        x = jnp.asarray(x)\n        new_shape = tuple(nested_lengths) + x.shape[1:]\n        return x.reshape(new_shape)\n\n    sub_xs = jax.tree_map(nested_reshape, xs)\n    return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
    "},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)","text":"

    Compute current at the post synapse.

    All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.

    Source code in jaxley/utils/syn_utils.py
    def gather_synapes(\n    number_of_compartments: jnp.ndarray,\n    post_syn_comp_inds: np.ndarray,\n    current_each_synapse_voltage_term: jnp.ndarray,\n    current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n    \"\"\"Compute current at the post synapse.\n\n    All this does it that it sums the synaptic currents that come into a particular\n    compartment. It returns an array of as many elements as there are compartments.\n    \"\"\"\n    incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n    incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n    dnums = ScatterDimensionNumbers(\n        update_window_dims=(),\n        inserted_window_dims=(0,),\n        scatter_dims_to_operand_dims=(0,),\n    )\n    incoming_currents_voltages = scatter_add(\n        incoming_currents_voltages,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_voltage_term,\n        dnums,\n    )\n    incoming_currents_contant = scatter_add(\n        incoming_currents_contant,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_constant_term,\n        dnums,\n    )\n    return incoming_currents_voltages, incoming_currents_contant\n
    "},{"location":"tutorial/01_morph_neurons/","title":"Basics of running simulations in Jaxley","text":"

    In this tutorial, you will learn how to:

    • build your first morphologically detailed cell or read it from SWC
    • stimulate the cell
    • record from the cell
    • visualize cells
    • run your first simulation

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the network.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, dt=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(net)\nplt.plot(v.T)\n

    First, we import the relevant libraries:

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

    We will now build our first cell in Jaxley. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.

    "},{"location":"tutorial/01_morph_neurons/#define-the-cell-from-scratch","title":"Define the cell from scratch","text":"

    To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\n

    Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1 entry means that this branch does not have a parent.

    parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n
    "},{"location":"tutorial/01_morph_neurons/#read-the-cell-from-an-swc-file","title":"Read the cell from an SWC file","text":"

    Alternatively, you could also load cells from SWC with

    cell = jx.read_swc(fname, nseg=4).

    "},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"

    Cells can be visualized as follows:

    cell.compute_xyz()  # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n

    "},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"

    Currently, the cell does not contain any kind of ion channel (not even a leak). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.

    cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n

    The easiest way to know which branch is the zero-eth branch (or, e.g., the zero-eth compartment of the zero-eth branch) is to plot it in a different color:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n_ = cell.branch(0).vis(ax=ax, col=\"r\")\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n

    "},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"

    We next stimulate one of the compartments with a step current. For this, we first define the step current:

    dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n

    We then stimulate one of the compartments of the cell with this step current:

    cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
    Added 1 external_states. See `.externals` for details.\n
    "},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"

    Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:

    cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
    Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n

    We can again visualize these locations to understand where we inserted recordings:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, col=\"g\")\n

    "},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"

    Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate:

    voltages = jx.integrate(cell)\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (2, 402)\n

    The jx.integrate function returns an array of shape (num_recordings, num_timepoints). In our case, we inserted2` recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.

    We can now visualize the voltage response:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n

    At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.

    Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to modify parameters of your simulation. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.

    "},{"location":"tutorial/02_small_network/","title":"Network simulations in Jaxley","text":"

    In this tutorial, you will learn how to:

    • connect neurons into a network
    • visualize networks

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n    net.cell(range(10)),\n    net.cell(10),\n    IonotropicSynapse(),\n)\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1])  # or `detail=\"point\"`.\n

    In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
    "},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"

    First, we define a cell as you saw in the previous tutorial.

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

    We can assemble multiple cells into a network by using jx.Network, which takes a list of jx.Cells. Here, we assemble 11 cells into a network:

    num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n

    At this point, we can already visualize this network:

    net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    Note: you can use move_to to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200)

    As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.

    We can use Jaxley\u2019s fully_connect method to connect these layers:

    pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    Let\u2019s visualize this again:

    fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    As you can see, the full_connect method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect method:

    pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    "},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"

    We will now set up a simulation of the network. This works exactly as it does for single neurons:

    # Stimulus.\ni_delay = 3.0  # ms\ni_amp = 0.05  # nA\ni_dur = 2.0  # ms\n\n# Duration and step size.\ndt = 0.025  # ms\nt_max = 50.0  # ms\n
    time_vec = jnp.arange(0.0, t_max + dt, dt)\n

    As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.

    net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n

    We stimulate every neuron in the input layer and record the voltage from the output neuron:

    current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n    net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n

    Finally, we can again run the network simulation and plot the result:

    s = jx.integrate(net)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n

    That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. Next, you should learn how to modify parameters of your simulation in this tutorial.

    "},{"location":"tutorial/03_setting_parameters/","title":"Setting parameters and initial states","text":"

    In this tutorial, you will learn how to:

    • set parameters of Jaxley models such as compartment radius or channel conductances
    • set initial states
    • set synaptic parameters

    Here is a code snippet which you will learn to understand in this tutorial:

    cell = ...  # See tutorial on Basics of Jaxley.\ncell.insert(Na())\n\ncell.set(\"radius\", 1.0)  # Set compartment radius.\ncell.branch(0).set(\"Na_gNa\", 0.1)  # Set sodium maximal conductance.\ncell.set(\"v\", -65.0)  # Set initial voltage.\n\nnet = ...  # See tutorial on Networks of Jaxley.\nfully_connect(net.cell(0), net.cell(1), IonotropicSynapse())\nnet.IonotropicSynapse().set(\"IonotropicSynapse_gS\", 0.01)\n

    In the previous two tutorials, you learned how to build single cells or networks and how to simulate them. In this tutorial, you will learn how to change parameters of such simulations.

    Let\u2019s get started!

    import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
    "},{"location":"tutorial/03_setting_parameters/#preface-building-the-cell-or-network","title":"Preface: Building the cell or network","text":"

    We first build a cell (or network) in the same way as we showed in the previous tutorials:

    dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\n
    "},{"location":"tutorial/03_setting_parameters/#setting-parameters-in-jaxley","title":"Setting parameters in Jaxley","text":"

    To modify parameters of the simulation, you can use the .set() method. For example

    cell.set(\"radius\", 0.1)\n
    will modify the radius of every compartment in the cell to 0.1 micrometer. You can also modify the parameters only of some branches:
    cell.branch(0).set(\"radius\", 1.0)\n
    or even of compartments:
    cell.branch(0).comp(0).set(\"radius\", 10.0)\n

    You can always inspect the current parameters by inspecting cell.nodes, which is a pandas Dataframe that contains all information about the cell. You can use .set() to set morphological parameters, channel parameters, synaptic parameters, and initial states, as outlined below:

    "},{"location":"tutorial/03_setting_parameters/#setting-morphological-parameters","title":"Setting morphological parameters","text":"

    Jaxley allows to set the following morphological parameters:

    • radius: the radius of the (zylindrical) compartment (in micrometer)
    • length: the length of the zylindrical compartment (in micrometer)
    • axial_resistivity: the resistivity of current flow between compartments (in ohm centimeter)
    cell.branch(0).set(\"axial_resistivity\", 1000.0)\ncell.set(\"length\", 1.0)  # This will set every compartment in the cell to have length 1.0.\n
    cell.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 1.0 1.0 1000.0 1.0 -70.0 1 1 0 0 1.0 1.0 1000.0 1.0 -70.0 2 2 1 0 1.0 1.0 5000.0 1.0 -70.0 3 3 1 0 1.0 1.0 5000.0 1.0 -70.0"},{"location":"tutorial/03_setting_parameters/#setting-channel-parameters","title":"Setting channel parameters","text":"

    You can also modify channel parameters. Every parameter that should be modifiable has to be defined in self.channel_params of the channel.

    cell.insert(Na())\ncell.branch(1).comp(0).set(\"Na_gNa\", 0.1)\n
    cell.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa eNa vt Na_m Na_h 0 0 0 0 1.0 1.0 1000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2 1 1 0 0 1.0 1.0 1000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2 2 2 1 0 1.0 1.0 5000.0 1.0 -70.0 True 0.10 50.0 -60.0 0.2 0.2 3 3 1 0 1.0 1.0 5000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2"},{"location":"tutorial/03_setting_parameters/#setting-synaptic-parameters","title":"Setting synaptic parameters","text":"

    In order to set parameters of synapses, you have to use net.SynapseName.set(), e.g.:

    from jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n\nnum_cells = 2\nnet = jx.Network([cell for _ in range(num_cells)])\nfully_connect(net.cell(0), net.cell(1), IonotropicSynapse())\n\n# Unlike for channels, you have to index into the synapse with `net.SynapseName`\nnet.IonotropicSynapse.set(\"IonotropicSynapse_gS\", 0.1)\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    You can inspect synaptic parameters and states with net.edges:

    net.edges\n
    pre_locs pre_branch_index pre_cell_index post_locs post_branch_index post_cell_index type type_ind global_pre_comp_index global_post_comp_index global_pre_branch_index global_post_branch_index IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s 0 0.25 0 0 0.25 1 1 IonotropicSynapse 0 0 6 0 3 0.1 0.0 0.025 0.2"},{"location":"tutorial/03_setting_parameters/#setting-initial-states","title":"Setting initial states","text":"

    Finally, you can also set initial states. These include the initial voltage v and the states of all channels and synapses (which must be defined in self.channel_states of the channel. For example:

    net.set(\"v\", -72.0)\nnet.IonotropicSynapse.set(\"IonotropicSynapse_s\", 0.1)\n
    net.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa eNa vt Na_m Na_h 0 0 0 0 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 1 1 0 0 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 2 2 1 0 1.0 1.0 5000.0 1.0 -72.0 True 0.10 50.0 -60.0 0.2 0.2 3 3 1 0 1.0 1.0 5000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 4 4 2 1 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 5 5 2 1 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 6 6 3 1 1.0 1.0 5000.0 1.0 -72.0 True 0.10 50.0 -60.0 0.2 0.2 7 7 3 1 1.0 1.0 5000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2
    net.edges\n
    pre_locs pre_branch_index pre_cell_index post_locs post_branch_index post_cell_index type type_ind global_pre_comp_index global_post_comp_index global_pre_branch_index global_post_branch_index IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s 0 0.25 0 0 0.25 1 1 IonotropicSynapse 0 0 6 0 3 0.1 0.0 0.025 0.1"},{"location":"tutorial/03_setting_parameters/#summary","title":"Summary","text":"

    You can now modify parameters of your Jaxley simulation. In the next tutorial, you will learn how to make parameter sweeps (or stimulus sweeps) fast with jit-compilation and GPU parallelization.

    "},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations with JIT-compilation and GPUs","text":"

    In this tutorial, you will learn how to:

    • make parameter sweeps in Jaxley
    • use jit to compile your simulations and make them faster
    • use vmap to parallelize simulations on GPUs

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap\n\n\ncell = ...  # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n

    In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:

    • by using JIT compilation
    • by using GPU parallelization

    Let\u2019s get started!

    "},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"

    In Jaxley you can set whether you want to use gpu or cpu with the following lines at the beginning of your script:

    from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n

    JAX (and Jaxley) also allow to choose between float32 and float64. Especially on GPUs, float32 will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32.

    config.update(\"jax_enable_x64\", True)  # Set to false to use `float32`.\n

    Next, we will import relevant libraries:

    import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
    "},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"

    We first build a cell (or network) in the same way as we showed in the previous tutorials:

    dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    "},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"

    Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):

    def simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state)\n

    The .data_set() method takes three arguments:

    1) the name of the parameter you want to set. Jaxley allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state which is initialized as None and is modified by .data_set(). This has to be passed to jx.integrate().

    Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:

    # Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (5, 1, 402)\n

    The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).

    "},{"location":"tutorial/04_jit_and_vmap/#stimulus-sweeps","title":"Stimulus sweeps","text":"

    In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate() method:

    def simulate(i_amp):\n    current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n    data_stimuli = None\n    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n    return jx.integrate(cell, data_stimuli=data_stimuli)\n

    "},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit compilation","text":"

    We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX\u2019s jit():

    jitted_simulate = jit(simulate)\n
    # First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
    # More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (5, 1, 402)\n

    jit compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed up at all).

    "},{"location":"tutorial/04_jit_and_vmap/#speeding-up-with-gpu-parallelization-via-vmap","title":"Speeding up with GPU parallelization via vmap","text":"

    Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley can be achieved by using vmap of JAX. To do this, we first create a new function that handles multiple parameter sets directly:

    # Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n

    We can then run this method on all parameter sets (all_params.shape == (100, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu as device in the beginning of this tutorial.

    voltages = vmapped_simulate(all_params)\n

    GPU parallelization with vmap can give a large speed-up, which can easily be 2-3 orders of magnitude.

    "},{"location":"tutorial/04_jit_and_vmap/#combining-jit-and-vmap","title":"Combining jit and vmap","text":"

    Finally, you can also combine using jit and vmap. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap and simulating each batch can be compiled with jit:

    jitted_vmapped_simulate = jit(vmap(simulate))\n
    for batch in range(10):\n    all_params = jnp.asarray(np.random.rand(5, 2))\n    voltages_batch = jitted_vmapped_simulate(all_params)\n

    That\u2019s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.

    "},{"location":"tutorial/04_jit_and_vmap/#next-steps","title":"Next steps","text":"

    If you want to learn more, we recommend you to read the tutorial on building channel and synapse models or to read the tutorial on groups, which allow to make your Jaxley simulations more elegant and convenient to interact with.

    Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

    "},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building and using ion channel models","text":"

    In this tutorial, you will learn how to:

    • define your own ion channel models beyond the preconfigured channels in Jaxley

    This tutorial assumes that you have already learned how to build basic simulations.

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n

    First, we define a cell as you saw in the previous tutorial:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

    You have also already learned how to insert preconfigured channels into Jaxley models:

    cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n

    In this tutorial, we will show you how to build your own channel and synapse models.

    "},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"

    Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

    import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n\n    def update_states(self, states, dt, v, params):\n        \"\"\"Update state.\"\"\"\n        ns = states[\"n_new\"]\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        new_n = solve_gate_exponential(ns, dt, alpha, beta)\n        return {\"n_new\": new_n}\n\n    def compute_current(self, states, v, params):\n        \"\"\"Return current.\"\"\"\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4 * 1000  # mS/cm^2\n\n        e_kd = -77.0        \n        return kd_conds * (v - e_kd)\n\n    def init_state(self, v, params):\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        return {\"n_new\": alpha / (alpha + beta)}\n

    Let\u2019s look at each part of this in detail.

    The below is simply a helper function for the solver of the gate variables:

    def exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n

    Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name.

    class Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name=None):\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n

    Next, we have the update_states() method, which updates the gating variables:

        def update_states(self, states, dt, v, params):\n

    The inputs states to the update_states method is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let\u2019s get the state of the potassium channel which we are building here:

    ns = states[\"n_new\"]\n

    Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

    alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n

    A channel also needs a compute_current() method which returns the current throught the channel:

        def compute_current(self, states, v, params):\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4 * 1000  # mS/cm^2\n\n        e_kd = -77.0        \n        current = kd_conds * (v - e_kd)\n        return current\n

    Finally, the init_state() method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states() is run.

    Alright, done! We can now insert this channel into any jx.Module such as our cell:

    cell.insert(Potassium())\n
    cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    s = jx.integrate(cell)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

    "},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"

    The parts below assume that you have already learned how to build network simulations in Jaxley.

    The below is an example of how to define your own synapse model in Jaxley:`

    import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n    \"\"\"\n    Compute syanptic current and update syanpse state.\n    \"\"\"\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n        self.synapse_states = {\"s_chol\": 0.1}\n\n    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n        \"\"\"Return updated synapse state and current.\"\"\"\n        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n        exp_term = jnp.exp(-delta_t)\n        new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {\"s_chol\": new_s}\n\n    def compute_current(self, states, pre_voltage, post_voltage, params):\n        g_syn = params[\"gChol\"] * states[\"s_chol\"]\n        return g_syn * (post_voltage - params[\"eChol\"])\n

    As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current method takes two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()).

    net = jx.Network([cell for _ in range(3)])\n
    from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n
    net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n    net.cell(i).branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    s = jx.integrate(net)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

    That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

    This tutorial does not have an immediate follow-up tutorial. You could read the tutorial on groups, which allow to make your Jaxley simulations more elegant and convenient to interact with.

    Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

    "},{"location":"tutorial/06_groups/","title":"Defining groups for easier handling of complex networks","text":"

    In this tutorial, you will learn how to:

    • define groups (aka sectionlists) to simplify iteractions with Jaxley

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap\n\n\nnet = ...  # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n    param_state = None\n    param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n    param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n    return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n

    In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

    First, we define a network as you saw in the previous tutorial:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n
    "},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"

    Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:

    for cell_ind in range(3):\n    network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n    network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n

    After this, we can access network.apical as we previously accesses anything else:

    network.apical.set(\"radius\", 0.3)\n
    network.apical.view\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa ... K_gK eK K_n Leak Leak_gLeak Leak_eLeak global_comp_index global_branch_index global_cell_index controlled_by_param 2 2 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 2 1 0 0 3 3 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 3 1 0 0 6 6 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 6 3 0 0 7 7 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 7 3 0 0 10 10 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 10 5 1 0 11 11 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 11 5 1 0 14 14 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 14 7 1 0 15 15 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 15 7 1 0 18 18 9 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 18 9 2 0 19 19 9 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 19 9 2 0 22 22 11 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 22 11 2 0 23 23 11 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 23 11 2 0

    12 rows \u00d7 25 columns

    "},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"

    Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:

    network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
    network.fast_spiking.set(\"Na_gNa\", 0.4)\n
    network.fast_spiking.view\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa ... K_gK eK K_n Leak Leak_gLeak Leak_eLeak global_comp_index global_branch_index global_cell_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 0 0 0 0 1 1 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 1 0 0 0 2 2 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 2 1 0 0 3 3 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 3 1 0 0 4 4 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 4 2 0 0 5 5 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 5 2 0 0 6 6 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 6 3 0 0 7 7 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 7 3 0 0 8 8 4 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 8 4 1 0 9 9 4 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 9 4 1 0 10 10 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 10 5 1 0 11 11 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 11 5 1 0 12 12 6 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 12 6 1 0 13 13 6 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 13 6 1 0 14 14 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 14 7 1 0 15 15 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 15 7 1 0

    16 rows \u00d7 25 columns

    "},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"

    Note: If you are reading swc morphologigies, you can automatically assign groups with jx.read_swc(file_name, assign_groups=True). After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.

    "},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()","text":"

    If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

    network.fast_spiking.make_trainable(\"Na_gNa\")\n
    Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

    As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

    network.get_parameters()\n
    [{'Na_gNa': Array([0.4], dtype=float64)}]\n

    If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1,3]):

    network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
    Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
    network.get_parameters()\n
    [{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n

    This generated three parameters for the axial resistivitiy, each corresponding to one cell.

    "},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"

    Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable().

    "},{"location":"tutorial/07_gradient_descent/","title":"Training biophysical models","text":"

    In this tutorial, you will learn how to train biophysical models in Jaxley. This includes the following:

    • compute the gradient with respect to parameters
    • use parameter transformations
    • use multi-level checkpointing
    • define optimizers
    • write dataloaders and parallelize across data

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap, value_and_grad\nimport jaxley as jx\n\n\nnet = ...  # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform(\n    lowers={\"HH_gNa\": 0.0, \"IonotropicSynapse_gS\": 0.0},\n    uppers={\"HH_gNa\": 1.0, \"IonotropicSynapse_gS\": 1.0},\n)\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n    current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n    data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20])\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n    params = transform.forward(opt_params)\n    voltages = batch_simulate(params, datapoints)\n    return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n    for batch in dataloader:\n        stimuli = batch[0].numpy()\n        labels = batch[1].numpy()\n        loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n        # Optimizer step.\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n

    First, we define a network as you saw in the previous tutorial:

    _ = np.random.seed(0)  # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    This network consists of three neurons arranged in two layers:

    net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\", layers=[2, 1], layer_kwargs={\"within_layer_offset\": 100.0, \"between_layer_offset\": 100.0}) \n

    We consider the last neuron as the output neuron and record the voltage from there:

    net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
    Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    "},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"

    We will train this biophysical network on a classification task. The inputs will be values and the label is binary:

    inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n

    labels = labels.astype(float)\n
    "},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"
    net.delete_trainables()\n

    This follows the same API as .set() seen in the previous tutorial. If you want to use a single parameter for all radiuses in the entire network, do:

    net.make_trainable(\"radius\")\n
    Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

    We can also define parameters for individual compartments. To do this, use the \"all\" key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:

    net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
    Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
    "},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"

    Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do

    net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n

    Here, we use a different syanptic conductance for all syanpses. This can be done as follows:

    net.TanhRateSynapse(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
    Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
    "},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"

    Once all parameters are defined, you have to use .get_parameters() to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:

    params = net.get_parameters()\n

    You can now run the simulation with the trainable parameters by passing them to the jx.integrate function.

    s = jx.integrate(net, params=params, t_max=10.0)\n
    "},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"

    The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:

    def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n

    We can also inspect some traces:

    traces = batched_simulate(params, inputs[:4])\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n

    "},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"

    Let us define a loss function to be optimized:

    def loss(params, inputs, labels):\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0) / 5  # Such that the prediction is roughly in [0, 1].\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n

    And we can use JAX\u2019s inbuilt functions to take the gradient through the entire ODE:

    jitted_grad = jit(value_and_grad(loss, argnums=0))\n
    value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
    "},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"

    Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)

    transform = jx.ParamTransform(\n    lowers={\n        \"Leak_gLeak\": 1e-5,\n        \"radius\": 0.1,\n        \"TanhRateSynapse_gS\": 1e-5,\n    },\n    uppers={\n        \"Leak_gLeak\": 1e-3,\n        \"radius\": 5.0,\n        \"TanhRateSynapse_gS\": 1e-2,\n    }, \n)\n

    With these modify the loss function acocrdingly:

    def loss(opt_params, inputs, labels):\n    transform.forward(opt_params)\n\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0)  # Such that the prediction is around 0.\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n
    "},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"

    Checkpointing allows to vastly reduce the memory requirements of training biophysical models.

    t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n

    To enable checkpointing, we have to modify the simulate function appropriately and use

    jx.integrate(..., checkpoint_inds=checkpoints)\n
    as done below:

    def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n    traces = simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[2])  # Use the average over time of the output neuron (2) as prediction.\n    return prediction + 72.0  # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n    params = transform.forward(opt_params)\n\n    predictions = batched_predict(params, inputs)\n    losses = jnp.abs(predictions - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
    "},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"

    We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax first):

    import optax\n
    opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
    "},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"
    import tensorflow as tf\nfrom tensorflow.data import Dataset\n
    batch_size = 4\n\ntf.random.set_seed(1)\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(batch_size)\n
    "},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"
    for epoch in range(10):\n    epoch_loss = 0.0\n    for batch_ind, batch in enumerate(dataloader):\n        current_batch = batch[0].numpy()\n        label_batch = batch[1].numpy()\n        loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n        epoch_loss += loss_val\n\n    print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
    epoch 0, loss 25.61663325387099\nepoch 1, loss 21.7304402547341\nepoch 2, loss 15.943236054666484\nepoch 3, loss 9.191846765081072\nepoch 4, loss 7.256558484588674\nepoch 5, loss 6.577375342584615\nepoch 6, loss 6.568056585075223\nepoch 7, loss 6.510474263850299\nepoch 8, loss 6.481302675498705\nepoch 9, loss 6.5030439519558865\n
    ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n

    Indeed, the loss goes down and the network successfully classifies the patterns.

    "},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"

    Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:

    • compute the gradient with respect to parameters
    • use parameter transformations
    • use multi-level checkpointing
    • define optimizers
    • write dataloaders and parallelize across data

    This was the last tutorial of the Jaxley toolbox. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!

    "},{"location":"tutorial/08_importing_morphologies/","title":"Working with morphologies","text":"

    In this tutorial, you will learn how to:

    • Load morphologies and make them compatible with Jaxley
    • How to use the visualization features
    • How to assemble a small network of morphologically accurate cells.

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", nseg=4, assign_groups=True)\n

    To work with more complicated morphologies, Jaxley supports importing morphological reconstructions via .swc files. .swc is currently the only supported format. Other formats like .asc need to be converted to .swc first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc see here.

    import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n

    To work with .swc files, Jaxley implements a custom .swc reader. The reader traces the morphology and identifies all uninterrupted sections. These are then partitioned into branches, each of which will be approximated by a number of equally many compartments that can be simulated fully in parallel.

    To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.

    # import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, nseg=8, max_branch_len=2000.0, assign_groups=True)\n\n# print shape (num_cells, num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
    (1, 157, 8)\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 0.01250 8.119 5000.0 1.0 -70.0 1 1 0 0 0.01250 8.119 5000.0 1.0 -70.0 2 2 0 0 0.01250 8.119 5000.0 1.0 -70.0 3 3 0 0 0.01250 8.119 5000.0 1.0 -70.0 4 4 0 0 0.01250 8.119 5000.0 1.0 -70.0 ... ... ... ... ... ... ... ... ... 1251 1251 156 0 24.12382 0.550 5000.0 1.0 -70.0 1252 1252 156 0 24.12382 0.550 5000.0 1.0 -70.0 1253 1253 156 0 24.12382 0.550 5000.0 1.0 -70.0 1254 1254 156 0 24.12382 0.550 5000.0 1.0 -70.0 1255 1255 156 0 24.12382 0.550 5000.0 1.0 -70.0

    1256 rows \u00d7 8 columns

    As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:

    cell = jx.read_swc(fname, nseg=2, max_branch_len=2000.0, assign_groups=True)\n\n# print shape (num_cells, num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
    (1, 157, 2)\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 1 1 0 0 0.050000 8.119000 5000.0 1.0 -70.0 2 2 1 0 6.241557 7.493344 5000.0 1.0 -70.0 3 3 1 0 6.241557 4.273686 5000.0 1.0 -70.0 4 4 2 0 4.160500 7.960000 5000.0 1.0 -70.0 ... ... ... ... ... ... ... ... ... 309 309 154 0 49.728572 0.400000 5000.0 1.0 -70.0 310 310 155 0 46.557908 0.494201 5000.0 1.0 -70.0 311 311 155 0 46.557908 0.302202 5000.0 1.0 -70.0 312 312 156 0 96.495281 0.742532 5000.0 1.0 -70.0 313 313 156 0 96.495281 0.550000 5000.0 1.0 -70.0

    314 rows \u00d7 8 columns

    # visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n

    While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc reconstruction irrespective of the number of compartments used. This is stored in the cell.xyzr attribute in a per branch fashion.

    To highlight each branch seperately, we can iterate over them.

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i in range(cell.shape[1]):\n    cell.branch(i).vis(ax=ax, col=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n

    Since Jaxley supports grouping different branches or compartments together, we can also use the id labels provided by the .swc file to assign group labels to the jx.Cell object.

    print(list(cell.group_nodes.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, col=colors[2])\ncell.soma.vis(ax=ax, col=colors[1])\ncell.apical.vis(ax=ax, col=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
    ['soma', 'basal', 'apical', 'custom']\n

    To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.

    net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n
    /home/jnsbck/Uni/PhD/projects/jaxleyverse/jaxley/jaxley/modules/base.py:1528: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutroial on how to build a network.

    "}]} \ No newline at end of file +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"

    Jaxley is a differentiable simulator for biophysical neuron models in JAX. Its key features are:

    • automatic differentiation, allowing gradient-based optimization of thousands of parameters
    • support for CPU, GPU, or TPU without any changes to the code
    • jit-compilation, making it as fast as other packages while being fully written in python
    • backward-Euler solver for stable numerical solution of multicompartment neurons
    • elegant mechanisms for parameter sharing
    "},{"location":"#getting-started","title":"Getting started","text":"

    Jaxley allows to simulate biophysical neuron models on CPU, GPU, or TPU:

    import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\")  # Or \"gpu\" / \"tpu\".\n\ncell = jx.Cell()  # Define cell.\ncell.insert(HH())  # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.stimulate(current)  # Stimulate with step current.\ncell.record(\"v\")  # Record voltage.\n\nv = jx.integrate(cell)  # Run simulation.\nplt.plot(v.T)  # Plot voltage trace.\n

    If you want to learn more, we have tutorials on how to:

    • simulate morphologically detailed neurons
    • simulate networks of such neurons
    • set parameters of cells and networks
    • speed up simulations with GPUs and jit
    • define your own channels and synapses
    • define groups
    • read and handle SWC files
    • compute the gradient and train biophysical models
    "},{"location":"#installation","title":"Installation","text":"

    Jaxley is available on pypi:

    pip install jaxley\n
    This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
    pip install -U \"jax[cuda12]\"\n

    "},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"

    We welcome any feedback on how Jaxley is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.

    "},{"location":"#license","title":"License","text":"

    Apache License Version 2.0 (Apache-2.0)

    "},{"location":"#citation","title":"Citation","text":"

    If you use Jaxley, consider citing the corresponding paper:

    @article{deistler2024differentiable,\n  doi = {10.1101/2024.08.21.608979},\n  year = {2024},\n  publisher = {Cold Spring Harbor Laboratory},\n  author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n  title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n  journal = {bioRxiv}\n}\n
    "},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"

    We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.

    We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

    "},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"

    Examples of behavior that contributes to a positive environment for our community include:

    • Demonstrating empathy and kindness toward other people
    • Being respectful of differing opinions, viewpoints, and experiences
    • Giving and gracefully accepting constructive feedback
    • Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
    • Focusing on what is best not just for us as individuals, but for the overall community

    Examples of unacceptable behavior include:

    • The use of sexualized language or imagery, and sexual attention or advances of any kind
    • Trolling, insulting or derogatory comments, and personal or political attacks
    • Public or private harassment
    • Publishing others\u2019 private information, such as a physical or email address, without their explicit permission
    • Other conduct which could reasonably be considered inappropriate in a professional setting
    "},{"location":"code_of_conduct/#enforcement-responsibilities","title":"Enforcement Responsibilities","text":"

    Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

    Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

    "},{"location":"code_of_conduct/#scope","title":"Scope","text":"

    This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

    "},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"

    Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.

    All community leaders are obligated to respect the privacy and security of the reporter of any incident.

    "},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"

    Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

    "},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"

    Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

    Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

    "},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"

    Community Impact: A violation through a single incident or series of actions.

    Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

    "},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"

    Community Impact: A serious violation of community standards, including sustained inappropriate behavior.

    Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

    "},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"

    Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

    Consequence: A permanent ban from any sort of public interaction within the community.

    "},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"

    This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

    Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.

    For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

    "},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"

    To report bugs and suggest features (including better documentation), please head over to issues on GitHub.

    "},{"location":"contribute/#code-contributions","title":"Code contributions","text":"

    In general, we use pull requests to make changes to Jaxley. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley (details).

    "},{"location":"contribute/#development-environment","title":"Development environment","text":"

    Clone the repo and install via setup.py using pip install -e \".[dev]\" (the dev flag installs development and testing dependencies).

    "},{"location":"contribute/#style-conventions","title":"Style conventions","text":"

    For docstrings and comments, we use Google Style.

    Code needs to pass through the following tools, which are installed alongside Jaxley:

    black: Automatic code formatting for Python. You can run black manually from the console using black . in the top directory of the repository, which will format all files.

    isort: Used to consistently order imports. You can run isort manually from the console using isort in the top directory.

    black and isort are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.

    "},{"location":"contribute/#online-documentation","title":"Online documentation","text":"

    Most of the documentation is written in markdown (basic markdown guide).

    You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.

    "},{"location":"credits/","title":"Credits","text":"

    Jaxley is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).

    "},{"location":"credits/#license","title":"License","text":"

    Jaxley is licensed under the Apache License Version 2.0 (Apache-2.0) and

    Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.

    "},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"
    • We greatly benefited from previous toolboxes for simulating multicompartment neurons, in particular NEURON.
    "},{"location":"credits/#funding","title":"Funding","text":"

    This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).

    "},{"location":"faq/","title":"Frequently asked questions","text":"
    • What units does Jaxley use?
    • How can I save and load cells and networks?

    See also the discussion page and the issue tracker on the Jaxley GitHub repository for recent questions and problems.

    "},{"location":"install/","title":"Installation","text":""},{"location":"install/#install-the-most-recent-stable-version","title":"Install the most recent stable version","text":"

    Jaxley is available on PyPI:

    pip install jaxley\n
    This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
    pip install -U \"jax[cuda12]\"\n

    "},{"location":"install/#install-from-source","title":"Install from source","text":"

    You can also install Jaxley from source:

    git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n

    Note that pip>=21.3 is required to install the editable version with pyproject.toml see pip docs.

    "},{"location":"faq/question_01/","title":"What units does Jaxley use?","text":"

    Jaxley uses the same units as the NEURON simulator, which are listed here.

    "},{"location":"faq/question_02/","title":"How can I save and load cells and networks?","text":"

    All modules (i.e., compartments, branches, cells, and networks) in Jaxley can be saved and loaded with pickle:

    import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n    pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n    network = pickle.load(handle)\n

    "},{"location":"reference/connect/","title":"Connecting Cells","text":""},{"location":"reference/connect/#jaxley.connect.connect","title":"connect(pre, post, synapse_type)","text":"

    Connect two compartments with a chemical synapse.

    The pre- and postsynaptic compartments must be different compartments of the same network.

    Parameters:

    Name Type Description Default pre CompartmentView

    View of the presynaptic compartment.

    required post CompartmentView

    View of the postsynaptic compartment.

    required synapse_type Synapse

    The synapse to append

    required Source code in jaxley/connect.py
    def connect(\n    pre: \"CompartmentView\",\n    post: \"CompartmentView\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Connect two compartments with a chemical synapse.\n\n    The pre- and postsynaptic compartments must be different compartments of the\n    same network.\n\n    Args:\n        pre: View of the presynaptic compartment.\n        post: View of the postsynaptic compartment.\n        synapse_type: The synapse to append\n    \"\"\"\n    assert is_same_network(\n        pre, post\n    ), \"Pre and post compartments must be part of the same network.\"\n    assert np.all(\n        pre_comp_not_equal_post_comp(pre, post)\n    ), \"Pre and post compartments must be different.\"\n\n    pre._append_multiple_synapses(pre.view, post.view, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)","text":"

    Appends multiple connections which build a custom connected network.

    Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required connectivity_matrix ndarray[bool]

    A boolean matrix indicating the connections between cells.

    required Source code in jaxley/connect.py
    def connectivity_matrix_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n    connectivity_matrix: np.ndarray[bool],\n):\n    \"\"\"Appends multiple connections which build a custom connected network.\n\n    Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n    Entries > 0 in the matrix indicate a connection between the corresponding cells.\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        connectivity_matrix: A boolean matrix indicating the connections between cells.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n\n    assert connectivity_matrix.shape == (\n        pre_cell_view.shape[0],\n        post_cell_view.shape[0],\n    ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n    assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n    # get connection pairs from connectivity matrix\n    from_idx, to_idx = np.where(connectivity_matrix)\n    pre_cell_inds = pre_cell_inds[from_idx]\n    post_cell_inds = post_cell_inds[to_idx]\n\n    # Sample random postsynaptic compartments (global comp indices).\n    global_post_indices = [\n        sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_cell_inds\n    ]\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    idcs_to_zero = np.zeros_like(from_idx)\n    get_global_idx = post_cell_view.pointer._local_inds_to_global\n    global_pre_indices = get_global_idx(pre_cell_inds, idcs_to_zero, idcs_to_zero)\n    pre_rows = pre_cell_view.view.loc[global_pre_indices]\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)","text":"

    Appends multiple connections which build a fully connected layer.

    Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required Source code in jaxley/connect.py
    def fully_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Appends multiple connections which build a fully connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n    num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)\n\n    # Infer indices of (random) postsynaptic compartments.\n    global_post_indices = (\n        post_cell_view.view.groupby(\"cell_index\")\n        .sample(num_pre, replace=True)\n        .index.to_numpy()\n    )\n    global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    pre_rows = pre_cell_view[0, 0].view\n    # Repeat rows `num_post` times. See SO 50788508.\n    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/connect/#jaxley.connect.get_pre_post_inds","title":"get_pre_post_inds(pre_cell_view, post_cell_view)","text":"

    Get the unique cell indices of the pre- and postsynaptic cells.

    Source code in jaxley/connect.py
    def get_pre_post_inds(\n    pre_cell_view: \"CellView\", post_cell_view: \"CellView\"\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Get the unique cell indices of the pre- and postsynaptic cells.\"\"\"\n    pre_cell_inds = np.unique(pre_cell_view.view[\"cell_index\"].to_numpy())\n    post_cell_inds = np.unique(post_cell_view.view[\"cell_index\"].to_numpy())\n    return pre_cell_inds, post_cell_inds\n
    "},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)","text":"

    Check if views are from the same network.

    Source code in jaxley/connect.py
    def is_same_network(pre: \"View\", post: \"View\") -> bool:\n    \"\"\"Check if views are from the same network.\"\"\"\n    is_in_net = \"network\" in pre.pointer.__class__.__name__.lower()\n    is_in_same_net = pre.pointer is post.pointer\n    return is_in_net and is_in_same_net\n
    "},{"location":"reference/connect/#jaxley.connect.pre_comp_not_equal_post_comp","title":"pre_comp_not_equal_post_comp(pre, post)","text":"

    Check if pre and post compartments are different.

    Source code in jaxley/connect.py
    def pre_comp_not_equal_post_comp(\n    pre: \"CompartmentView\", post: \"CompartmentView\"\n) -> np.ndarray[bool]:\n    \"\"\"Check if pre and post compartments are different.\"\"\"\n    cols = [\"cell_index\", \"branch_index\", \"comp_index\"]\n    return np.any(pre.view[cols].values != post.view[cols].values, axis=1)\n
    "},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, cell_idx, num=1, replace=True)","text":"

    Sample a compartment from a cell.

    Returns View with shape (num, num_cols).

    Source code in jaxley/connect.py
    def sample_comp(\n    cell_view: \"CellView\", cell_idx: int, num: int = 1, replace=True\n) -> \"CompartmentView\":\n    \"\"\"Sample a compartment from a cell.\n\n    Returns View with shape (num, num_cols).\"\"\"\n    cell_idx_view = lambda view, cell_idx: view[view[\"cell_index\"] == cell_idx]\n    return cell_idx_view(cell_view.view, cell_idx).sample(num, replace=replace)\n
    "},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)","text":"

    Appends multiple connections which build a sparse, randomly connected layer.

    Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Parameters:

    Name Type Description Default pre_cell_view CellView

    View of the presynaptic cell.

    required post_cell_view CellView

    View of the postsynaptic cell.

    required synapse_type Synapse

    The synapse to append.

    required p float

    Probability of connection.

    required Source code in jaxley/connect.py
    def sparse_connect(\n    pre_cell_view: \"CellView\",\n    post_cell_view: \"CellView\",\n    synapse_type: \"Synapse\",\n    p: float,\n):\n    \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        p: Probability of connection.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds, post_cell_inds = get_pre_post_inds(pre_cell_view, post_cell_view)\n    num_pre, num_post = len(pre_cell_inds), len(post_cell_inds)\n\n    num_connections = np.random.binomial(num_pre * num_post, p)\n    pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n    post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n    # Sort the synapses only for convenience of inspecting `.edges`.\n    sorting = np.argsort(pre_syn_neurons)\n    pre_syn_neurons = pre_syn_neurons[sorting]\n    post_syn_neurons = post_syn_neurons[sorting]\n\n    # Post-synapse is a randomly chosen branch and compartment.\n    global_post_indices = [\n        sample_comp(post_cell_view, cell_idx).index[0] for cell_idx in post_syn_neurons\n    ]\n    post_rows = post_cell_view.view.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    idcs_to_zero = np.zeros_like(num_pre)\n    get_global_idx = pre_cell_view.pointer._local_inds_to_global\n    global_pre_indices = get_global_idx(pre_syn_neurons, idcs_to_zero, idcs_to_zero)\n    pre_rows = pre_cell_view.view.loc[global_pre_indices]\n\n    pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
    "},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)","text":"

    Solves ODE and simulates neuron model.

    Parameters:

    Name Type Description Default params List[Dict[str, ndarray]]

    Trainable parameters returned by get_parameters().

    [] param_state Optional[List[Dict]]

    Parameters returned by data_set.

    None data_stimuli Optional[Tuple[ndarray, DataFrame]]

    Outputs of .data_stimulate(), only needed if stimuli change across function calls.

    None t_max Optional[float]

    Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, then the stimulus with be truncated.

    None delta_t float

    Time step of the solver in milliseconds.

    0.025 solver str

    Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccranck\u201d].

    'bwd_euler' tridiag_solver

    Algorithm to solve tridiagonal systems. The different options only affect bwd_euler and cranck solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON).

    required checkpoint_lengths Optional[List[int]]

    Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths) must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation time. If None, no checkpointing is applied.

    None all_states Optional[Dict]

    An optional initial state that was returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states.

    None return_states bool

    If True, it returns all states such that the current state of the Module can be set with set_states.

    False Source code in jaxley/integrate.py
    def integrate(\n    module: Module,\n    params: List[Dict[str, jnp.ndarray]] = [],\n    *,\n    param_state: Optional[List[Dict]] = None,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    t_max: Optional[float] = None,\n    delta_t: float = 0.025,\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n    checkpoint_lengths: Optional[List[int]] = None,\n    all_states: Optional[Dict] = None,\n    return_states: bool = False,\n) -> jnp.ndarray:\n    \"\"\"\n    Solves ODE and simulates neuron model.\n\n    Args:\n        params: Trainable parameters returned by `get_parameters()`.\n        param_state: Parameters returned by `data_set`.\n        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n            across function calls.\n        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n            the length of the stimulus input, the stimulus will be padded at the end\n            with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n        delta_t: Time step of the solver in milliseconds.\n        solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\", \"cranck\"].\n        tridiag_solver: Algorithm to solve tridiagonal systems. The  different options\n            only affect `bwd_euler` and `cranck` solvers. Either of [\"stone\",\n            \"thomas\"], where `stone` is much faster on GPU for long branches\n            with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n            used in NEURON).\n        checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n            `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n            simulated timesteps. Warning: the simulation is run for\n            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n            to the desired simulation length. Therefore, a poor choice of\n            `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n            checkpointing is applied.\n        all_states: An optional initial state that was returned by a previous\n            `jx.integrate(..., return_states=True)` run. Overrides potentially\n            trainable initial states.\n        return_states: If True, it returns all states such that the current state of\n            the `Module` can be set with `set_states`.\n    \"\"\"\n\n    assert module.initialized, \"Module is not initialized, run `.initialize()`.\"\n    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n    # Initialize the external inputs and their indices.\n    externals = module.externals.copy()\n    external_inds = module.external_inds.copy()\n\n    # If stimulus is inserted, add it to the external inputs.\n    if \"i\" in module.externals.keys() or data_stimuli is not None:\n        if \"i\" in module.externals.keys():\n            if data_stimuli is not None:\n                externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[0]])\n                external_inds[\"i\"] = jnp.concatenate(\n                    [external_inds[\"i\"], data_stimuli[1].comp_index.to_numpy()]\n                )\n        else:\n            externals[\"i\"] = data_stimuli[0]\n            external_inds[\"i\"] = data_stimuli[1].comp_index.to_numpy()\n    else:\n        externals[\"i\"] = jnp.asarray([[]]).astype(\"float\")\n        external_inds[\"i\"] = jnp.asarray([]).astype(\"int32\")\n\n    if not externals.keys():\n        # No stimulus was inserted and no clamp was set.\n        assert (\n            t_max is not None\n        ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n    for key in externals.keys():\n        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.\n\n    rec_inds = module.recordings.rec_index.to_numpy()\n    rec_states = module.recordings.state.to_numpy()\n\n    # Shorten or pad stimulus depending on `t_max`.\n    if t_max is not None:\n        t_max_steps = int(t_max // delta_t + 1)\n\n        # Pad or truncate the stimulus.\n        if \"i\" in externals.keys() and t_max_steps > externals[\"i\"].shape[0]:\n            pad = jnp.zeros(\n                (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n            )\n            externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n\n        for key in externals.keys():\n            if t_max_steps > externals[key].shape[0]:\n                raise NotImplementedError(\n                    \"clamp must be at least as long as simulation.\"\n                )\n            else:\n                externals[key] = externals[key][:t_max_steps, :]\n\n    # Make the `trainable_params` of the same shape as the `param_state`, such that they\n    # can be processed together by `get_all_parameters`.\n    pstate = params_to_pstate(params, module.indices_set_by_trainables)\n\n    # Gather parameters from `make_trainable` and `data_set` into a single list.\n    if param_state is not None:\n        pstate += param_state\n\n    all_params = module.get_all_parameters(pstate)\n    all_states = (\n        module.get_all_states(pstate, all_params, delta_t)\n        if all_states is None\n        else all_states\n    )\n\n    def _body_fun(state, externals):\n        state = module.step(\n            state,\n            delta_t,\n            external_inds,\n            externals,\n            params=all_params,\n            solver=solver,\n            voltage_solver=voltage_solver,\n        )\n        recs = jnp.asarray(\n            [\n                state[rec_state][rec_ind]\n                for rec_state, rec_ind in zip(rec_states, rec_inds)\n            ]\n        )\n        return state, recs\n\n    # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n    # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n    # return only the first `nsteps_to_return` elements (plus the initial state).\n    example_key = list(externals.keys())[0]\n    nsteps_to_return = len(externals[example_key])\n    if checkpoint_lengths is None:\n        checkpoint_lengths = [len(externals[example_key])]\n        length = len(externals[example_key])\n    else:\n        length = prod(checkpoint_lengths)\n        size_difference = length - len(externals[example_key])\n        dummy_external = jnp.zeros((size_difference, externals[example_key].shape[1]))\n        assert (\n            len(externals[example_key]) <= length\n        ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n        for key in externals.keys():\n            externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n    # Record the initial state.\n    init_recs = jnp.asarray(\n        [\n            all_states[rec_state][rec_ind]\n            for rec_state, rec_ind in zip(rec_states, rec_inds)\n        ]\n    )\n    init_recording = jnp.expand_dims(init_recs, axis=0)\n\n    # Run simulation.\n    all_states, recordings = nested_checkpoint_scan(\n        _body_fun,\n        all_states,\n        externals,\n        length=length,\n        nested_lengths=checkpoint_lengths,\n    )\n    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n    return (recs, all_states) if return_states else recs\n
    "},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)","text":"

    An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau.

    Source code in jaxley/solver_gate.py
    def exponential_euler(\n    x: jnp.ndarray,\n    dt: float,\n    x_inf: jnp.ndarray,\n    x_tau: jnp.ndarray,\n):\n    \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n    exp_term = save_exp(-dt / x_tau)\n    return x * exp_term + x_inf * (1.0 - exp_term)\n
    "},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)","text":"

    Clip the input to a maximum value and return its exponential.

    Source code in jaxley/solver_gate.py
    def save_exp(x, max_value: float = 20.0):\n    \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n    x = jnp.clip(x, a_max=max_value)\n    return jnp.exp(x)\n
    "},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)","text":"

    solves dx/dt = (s_inf - x) / tau_s via exponential Euler

    Parameters:

    Name Type Description Default x ndarray

    gate variable

    required dt float

    time_delta

    required s_inf ndarray

    description

    required tau_s ndarray

    description

    required

    Returns:

    Name Type Description _type_

    updated gate

    Source code in jaxley/solver_gate.py
    def solve_inf_gate_exponential(\n    x: jnp.ndarray,\n    dt: float,\n    s_inf: jnp.ndarray,\n    tau_s: jnp.ndarray,\n):\n    \"\"\"solves dx/dt = (s_inf - x) / tau_s\n    via exponential Euler\n\n    Args:\n        x (jnp.ndarray): gate variable\n        dt (float): time_delta\n        s_inf (jnp.ndarray): _description_\n        tau_s (jnp.ndarray): _description_\n\n    Returns:\n        _type_: updated gate\n    \"\"\"\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * dt)\n    return x * exp_term + s_inf * (1.0 - exp_term)\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd, nbranches, parents, delta_t)","text":"

    Solve one timestep of branched nerve equations with explicit (forward) Euler.

    Source code in jaxley/solver_voltage.py
    def step_voltage_explicit(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    coupling_conds_bwd: jnp.ndarray,\n    coupling_conds_fwd: jnp.ndarray,\n    branch_cond_fwd: jnp.ndarray,\n    branch_cond_bwd: jnp.ndarray,\n    nbranches: int,\n    parents: jnp.ndarray,\n    delta_t: float,\n) -> jnp.ndarray:\n    \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n    update = voltage_vectorfield(\n        parents,\n        voltages,\n        voltage_terms,\n        constant_terms,\n        coupling_conds_bwd,\n        coupling_conds_fwd,\n        branch_cond_fwd,\n        branch_cond_bwd,\n    )\n    new_voltates = voltages + delta_t * update\n    return new_voltates\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit","title":"step_voltage_implicit(voltages, voltage_terms, constant_terms, coupling_conds_upper, coupling_conds_lower, summed_coupling_conds, branchpoint_conds_children, branchpoint_conds_parents, branchpoint_weights_children, branchpoint_weights_parents, par_inds, child_inds, nbranches, solver, delta_t, children_in_level, parents_in_level, root_inds, branchpoint_group_inds, debug_states)","text":"

    Solve one timestep of branched nerve equations with implicit (backward) Euler.

    Source code in jaxley/solver_voltage.py
    def step_voltage_implicit(\n    voltages,\n    voltage_terms,\n    constant_terms,\n    coupling_conds_upper,\n    coupling_conds_lower,\n    summed_coupling_conds,\n    branchpoint_conds_children,\n    branchpoint_conds_parents,\n    branchpoint_weights_children,\n    branchpoint_weights_parents,\n    par_inds,\n    child_inds,\n    nbranches,\n    solver: str,\n    delta_t,\n    children_in_level,\n    parents_in_level,\n    root_inds,\n    branchpoint_group_inds,\n    debug_states,\n):\n    \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n    coupling_conds_upper = jnp.reshape(coupling_conds_upper, (nbranches, -1))\n    coupling_conds_lower = jnp.reshape(coupling_conds_lower, (nbranches, -1))\n    summed_coupling_conds = jnp.reshape(summed_coupling_conds, (nbranches, -1))\n\n    # Define quasi-tridiagonal system.\n    lowers, diags, uppers, solves = define_all_tridiags(\n        voltages,\n        voltage_terms,\n        constant_terms,\n        nbranches,\n        coupling_conds_upper,\n        coupling_conds_lower,\n        summed_coupling_conds,\n        delta_t,\n    )\n    all_branchpoint_vals = jnp.concatenate(\n        [branchpoint_weights_parents, branchpoint_weights_children]\n    )\n    # Find unique group identifiers\n    num_branchpoints = len(branchpoint_conds_parents)\n    branchpoint_diags = -group_and_sum(\n        all_branchpoint_vals, branchpoint_group_inds, num_branchpoints\n    )\n    branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n    branchpoint_conds_children = -delta_t * branchpoint_conds_children\n    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n    # Here, I move all child and parent indices towards a branchpoint into a larger\n    # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n    # makes the speed difference negligible.\n    # Children.\n    bp_conds_children = jnp.zeros(nbranches)\n    bp_weights_children = jnp.zeros(nbranches)\n    # Parents.\n    bp_conds_parents = jnp.zeros(nbranches)\n    bp_weights_parents = jnp.zeros(nbranches)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        bp_conds_children = bp_conds_children.at[child_inds].set(\n            branchpoint_conds_children\n        )\n        bp_weights_children = bp_weights_children.at[child_inds].set(\n            branchpoint_weights_children\n        )\n        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n        bp_weights_parents = bp_weights_parents.at[par_inds].set(\n            branchpoint_weights_parents\n        )\n\n    # Triangulate the linear system of equations.\n    (\n        diags,\n        lowers,\n        solves,\n        uppers,\n        branchpoint_diags,\n        branchpoint_solves,\n        bp_weights_children,\n        bp_conds_parents,\n    ) = _triang_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        children_in_level,\n        parents_in_level,\n        root_inds,\n        debug_states,\n    )\n\n    # Backsubstitute the linear system of equations.\n    (\n        solves,\n        lowers,\n        diags,\n        bp_weights_parents,\n        branchpoint_solves,\n        bp_conds_children,\n    ) = _backsub_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        children_in_level,\n        parents_in_level,\n        root_inds,\n        debug_states,\n    )\n\n    return solves\n
    "},{"location":"reference/integration/#jaxley.solver_voltage.voltage_vectorfield","title":"voltage_vectorfield(parents, voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd)","text":"

    Evaluate the vectorfield of the nerve equation.

    Source code in jaxley/solver_voltage.py
    def voltage_vectorfield(\n    parents: jnp.ndarray,\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    coupling_conds_bwd: jnp.ndarray,\n    coupling_conds_fwd: jnp.ndarray,\n    branch_cond_fwd: jnp.ndarray,\n    branch_cond_bwd: jnp.ndarray,\n) -> jnp.ndarray:\n    \"\"\"Evaluate the vectorfield of the nerve equation.\"\"\"\n    # Membrane current update.\n    vecfield = -voltage_terms * voltages + constant_terms\n\n    # Current through segments within the same branch.\n    vecfield = vecfield.at[:, :-1].add(\n        (voltages[:, 1:] - voltages[:, :-1]) * coupling_conds_bwd\n    )\n    vecfield = vecfield.at[:, 1:].add(\n        (voltages[:, :-1] - voltages[:, 1:]) * coupling_conds_fwd\n    )\n\n    # Current through branch points.\n    if len(branch_cond_bwd) > 0:\n        vecfield = vecfield.at[:, -1].add(\n            (voltages[parents, 0] - voltages[:, -1]) * branch_cond_bwd\n        )\n\n        # Several branches might have the same parent, so we have to either update these\n        # entries sequentially or we have to build a matrix with width being the maximum\n        # number of children and then sum.\n        term_to_add = (voltages[:, -1] - voltages[parents, 0]) * branch_cond_fwd\n        inds = jnp.stack([parents, jnp.zeros_like(parents)]).T\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0, 1),\n            scatter_dims_to_operand_dims=(0, 1),\n        )\n        vecfield = scatter_add(vecfield, inds, term_to_add, dnums)\n\n    return vecfield\n
    "},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"

    Channel base class. All channels inherit from this class.

    As in NEURON, a Channel is considered a distributed process, which means that its conductances are to be specified in S/cm2 and its currents are to be specified in uA/cm2.

    Source code in jaxley/channels/channel.py
    class Channel:\n    \"\"\"Channel base class. All channels inherit from this class.\n\n    As in NEURON, a `Channel` is considered a distributed process, which means that its\n    conductances are to be specified in `S/cm2` and its currents are to be specified in\n    `uA/cm2`.\"\"\"\n\n    _name = None\n    channel_params = None\n    channel_states = None\n    current_name = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the channel name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.channel_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_params.items()\n        }\n\n        self.channel_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_states.items()\n        }\n        return self\n\n    def update_states(\n        self, states, dt, v, params\n    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Given channel states and voltage, return the current through the channel.\n\n        Args:\n            states: All states of the compartment.\n            v: Voltage of the compartment in mV.\n            params: Parameters of the channel (conductances in `S/cm2`).\n\n        Returns:\n            Current in `uA/cm2`.\n        \"\"\"\n        raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str] property","text":"

    The name of the channel (by default, this is the class name).

    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)","text":"

    Change the channel name.

    Parameters:

    Name Type Description Default new_name str

    The new name of the channel.

    required

    Returns:

    Type Description

    Renamed channel, such that this function is chainable.

    Source code in jaxley/channels/channel.py
    def change_name(self, new_name: str):\n    \"\"\"Change the channel name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.channel_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_params.items()\n    }\n\n    self.channel_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_states.items()\n    }\n    return self\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)","text":"

    Given channel states and voltage, return the current through the channel.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    All states of the compartment.

    required v

    Voltage of the compartment in mV.

    required params Dict[str, ndarray]

    Parameters of the channel (conductances in S/cm2).

    required

    Returns:

    Type Description

    Current in uA/cm2.

    Source code in jaxley/channels/channel.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Given channel states and voltage, return the current through the channel.\n\n    Args:\n        states: All states of the compartment.\n        v: Voltage of the compartment in mV.\n        params: Parameters of the channel (conductances in `S/cm2`).\n\n    Returns:\n        Current in `uA/cm2`.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)","text":"

    Return the updated states.

    Source code in jaxley/channels/channel.py
    def update_states(\n    self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n    \"\"\"Return the updated states.\"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#hh","title":"HH","text":"

    Bases: Channel

    Hodgkin-Huxley channel.

    Source code in jaxley/channels/hh.py
    class HH(Channel):\n    \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 0.12,\n            f\"{prefix}_gK\": 0.036,\n            f\"{prefix}_gLeak\": 0.0003,\n            f\"{prefix}_eNa\": 50.0,\n            f\"{prefix}_eK\": -77.0,\n            f\"{prefix}_eLeak\": -54.3,\n        }\n        self.channel_states = {\n            f\"{prefix}_m\": 0.2,\n            f\"{prefix}_h\": 0.2,\n            f\"{prefix}_n\": 0.2,\n        }\n        self.current_name = f\"i_HH\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Return updated HH channel state.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current through HH channels.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n        gK = params[f\"{prefix}_gK\"] * n**4 * 1000  # mS/cm^2\n        gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n\n        return (\n            gNa * (v - params[f\"{prefix}_eNa\"])\n            + gK * (v - params[f\"{prefix}_eK\"])\n            + gLeak * (v - params[f\"{prefix}_eLeak\"])\n        )\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v)\n        alpha_h, beta_h = self.h_gate(v)\n        alpha_n, beta_n = self.n_gate(v)\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n            f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n        }\n\n    @staticmethod\n    def m_gate(v):\n        alpha = 0.1 * _vtrap(-(v + 40), 10)\n        beta = 4.0 * save_exp(-(v + 65) / 18)\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v):\n        alpha = 0.07 * save_exp(-(v + 65) / 20)\n        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n        return alpha, beta\n\n    @staticmethod\n    def n_gate(v):\n        alpha = 0.01 * _vtrap(-(v + 55), 10)\n        beta = 0.125 * save_exp(-(v + 65) / 80)\n        return alpha, beta\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)","text":"

    Return current through HH channels.

    Source code in jaxley/channels/hh.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current through HH channels.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n    gK = params[f\"{prefix}_gK\"] * n**4 * 1000  # mS/cm^2\n    gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n\n    return (\n        gNa * (v - params[f\"{prefix}_eNa\"])\n        + gK * (v - params[f\"{prefix}_eK\"])\n        + gLeak * (v - params[f\"{prefix}_eLeak\"])\n    )\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/hh.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v)\n    alpha_h, beta_h = self.h_gate(v)\n    alpha_n, beta_n = self.n_gate(v)\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)","text":"

    Return updated HH channel state.

    Source code in jaxley/channels/hh.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Return updated HH channel state.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
    "},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":"

    Bases: Channel

    Leak current

    Source code in jaxley/channels/pospischil.py
    class Leak(Channel):\n    \"\"\"Leak current\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gLeak\": 1e-4,\n            f\"{prefix}_eLeak\": -70.0,\n        }\n        self.channel_states = {}\n        self.current_name = f\"i_{prefix}\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"No state to update.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n        return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n    def init_state(self, v, params):\n        return {}\n

    Bases: Channel

    Sodium channel

    Source code in jaxley/channels/pospischil.py
    class Na(Channel):\n    \"\"\"Sodium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 50e-3,\n            \"eNa\": 50.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n        self.current_name = f\"i_Na\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n\n        current = gNa * (v - params[\"eNa\"])\n        return current\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n        alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        }\n\n    @staticmethod\n    def m_gate(v, vt):\n        v_alpha = v - vt - 13.0\n        alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n        v_beta = v - vt - 40.0\n        beta = 0.28 * efun(0.2 * v_beta) / 0.2\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v, vt):\n        v_alpha = v - vt - 17.0\n        alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n        v_beta = v - vt - 40.0\n        beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n        return alpha, beta\n

    Bases: Channel

    Potassium channel

    Source code in jaxley/channels/pospischil.py
    class K(Channel):\n    \"\"\"Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gK\": 5e-3,\n            \"eK\": -90.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_n\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gK = params[f\"{prefix}_gK\"] * (n**4) * 1000  # mS/cm^2\n\n        return gK * (v - params[\"eK\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n        return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n    @staticmethod\n    def n_gate(v, vt):\n        v_alpha = v - vt - 15.0\n        alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n        v_beta = v - vt - 10.0\n        beta = 0.5 * save_exp(-v_beta / 40.0)\n        return alpha, beta\n

    Bases: Channel

    Slow M Potassium channel

    Source code in jaxley/channels/pospischil.py
    class Km(Channel):\n    \"\"\"Slow M Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gKm\": 0.004e-3,\n            f\"{prefix}_taumax\": 4000.0,\n            f\"eK\": -90.0,\n        }\n        self.channel_states = {f\"{prefix}_p\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n        new_p = solve_inf_gate_exponential(\n            p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n        )\n        return {f\"{prefix}_p\": new_p}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gKm = params[f\"{prefix}_gKm\"] * p * 1000  # mS/cm^2\n        return gKm * (v - params[\"eK\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n        return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n    @staticmethod\n    def p_gate(v, taumax):\n        v_p = v + 35.0\n        p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n        tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n        return p_inf, tau_p\n

    Bases: Channel

    L-type Calcium channel

    Source code in jaxley/channels/pospischil.py
    class CaL(Channel):\n    \"\"\"L-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaL\": 0.1e-3,\n            \"eCa\": 120.0,\n        }\n        self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n        new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n        return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r * 1000  # mS/cm^2\n\n        return gCaL * (v - params[\"eCa\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_q, beta_q = self.q_gate(v)\n        alpha_r, beta_r = self.r_gate(v)\n        return {\n            f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n            f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n        }\n\n    @staticmethod\n    def q_gate(v):\n        v_alpha = -v - 27.0\n        alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n        v_beta = -v - 75.0\n        beta = 0.94 * save_exp(v_beta / 17.0)\n        return alpha, beta\n\n    @staticmethod\n    def r_gate(v):\n        v_alpha = -v - 13.0\n        alpha = 0.000457 * save_exp(v_alpha / 50)\n\n        v_beta = -v - 15.0\n        beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n        return alpha, beta\n

    Bases: Channel

    T-type Calcium channel

    Source code in jaxley/channels/pospischil.py
    class CaT(Channel):\n    \"\"\"T-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaT\": 0.4e-4,\n            f\"{prefix}_vx\": 2.0,\n            \"eCa\": 120.0,  # Global parameter, not prefixed with `CaT`.\n        }\n        self.channel_states = {f\"{prefix}_u\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        new_u = solve_inf_gate_exponential(\n            u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n        )\n        return {f\"{prefix}_u\": new_u}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u * 1000  # mS/cm^2\n\n        return gCaT * (v - params[\"eCa\"])\n\n    def init_state(self, v, params):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n        return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n    @staticmethod\n    def u_gate(v, vx):\n        v_u1 = v + vx + 81.0\n        u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n        tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n            3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n        )\n\n        return u_inf, tau_u\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gLeak = params[f\"{prefix}_gLeak\"] * 1000  # mS/cm^2\n    return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)","text":"

    No state to update.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"No state to update.\"\"\"\n    return {}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h * 1000  # mS/cm^2\n\n    current = gNa * (v - params[\"eNa\"])\n    return current\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n    alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gK = params[f\"{prefix}_gK\"] * (n**4) * 1000  # mS/cm^2\n\n    return gK * (v - params[\"eK\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n    return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_n\": new_n}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gKm = params[f\"{prefix}_gKm\"] * p * 1000  # mS/cm^2\n    return gKm * (v - params[\"eK\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n    return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n    new_p = solve_inf_gate_exponential(\n        p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n    )\n    return {f\"{prefix}_p\": new_p}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r * 1000  # mS/cm^2\n\n    return gCaL * (v - params[\"eCa\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_q, beta_q = self.q_gate(v)\n    alpha_r, beta_r = self.r_gate(v)\n    return {\n        f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n        f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n    }\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n    new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n    return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)","text":"

    Return current.

    Source code in jaxley/channels/pospischil.py
    def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n    # Multiply with 1000 to convert Siemens to milli Siemens.\n    gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u * 1000  # mS/cm^2\n\n    return gCaT * (v - params[\"eCa\"])\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(v, params)","text":"

    Initialize the state such at fixed point of gate dynamics.

    Source code in jaxley/channels/pospischil.py
    def init_state(self, v, params):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n    return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
    "},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)","text":"

    Update state.

    Source code in jaxley/channels/pospischil.py
    def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    new_u = solve_inf_gate_exponential(\n        u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n    )\n    return {f\"{prefix}_u\": new_u}\n
    "},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"

    Base class for a synapse.

    As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

    Source code in jaxley/synapses/synapse.py
    class Synapse:\n    \"\"\"Base class for a synapse.\n\n    As in NEURON, a `Synapse` is considered a point process, which means that its\n    conductances are to be specified in `uS` and its currents are to be specified in\n    `nA`.\n    \"\"\"\n\n    _name = None\n    synapse_params = None\n    synapse_states = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the synapse name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.synapse_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_params.items()\n        }\n\n        self.synapse_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_states.items()\n        }\n        return self\n\n    def update_states(\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"ODE update step.\n\n        Args:\n            states: States of the synapse.\n            delta_t: Time step in `ms`.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        states: Dict[str, jnp.ndarray],\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> jnp.ndarray:\n        \"\"\"Return current through one synapse in `nA`.\n\n        Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n        Args:\n            states: States of the synapse.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Current through the synapse in `nA`, shape `()`.\n        \"\"\"\n        raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)","text":"

    Change the synapse name.

    Parameters:

    Name Type Description Default new_name str

    The new name of the channel.

    required

    Returns:

    Type Description

    Renamed channel, such that this function is chainable.

    Source code in jaxley/synapses/synapse.py
    def change_name(self, new_name: str):\n    \"\"\"Change the synapse name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.synapse_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_params.items()\n    }\n\n    self.synapse_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_states.items()\n    }\n    return self\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

    Return current through one synapse in nA.

    Internally, we use jax.vmap to vectorize this function across many synapses.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    States of the synapse.

    required pre_voltage ndarray

    Voltage of the presynaptic compartment, shape ().

    required post_voltage ndarray

    Voltage of the postsynaptic compartment, shape ().

    required params Dict[str, ndarray]

    Parameters of the synapse. Conductances in uS.

    required

    Returns:

    Type Description ndarray

    Current through the synapse in nA, shape ().

    Source code in jaxley/synapses/synapse.py
    def compute_current(\n    states: Dict[str, jnp.ndarray],\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n    \"\"\"Return current through one synapse in `nA`.\n\n    Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n    Args:\n        states: States of the synapse.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Current through the synapse in `nA`, shape `()`.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    ODE update step.

    Parameters:

    Name Type Description Default states Dict[str, ndarray]

    States of the synapse.

    required delta_t float

    Time step in ms.

    required pre_voltage ndarray

    Voltage of the presynaptic compartment, shape ().

    required post_voltage ndarray

    Voltage of the postsynaptic compartment, shape ().

    required params Dict[str, ndarray]

    Parameters of the synapse. Conductances in uS.

    required

    Returns:

    Type Description Dict[str, ndarray]

    Updated states.

    Source code in jaxley/synapses/synapse.py
    def update_states(\n    states: Dict[str, jnp.ndarray],\n    delta_t: float,\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"ODE update step.\n\n    Args:\n        states: States of the synapse.\n        delta_t: Time step in `ms`.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Updated states.\"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":"

    Bases: Synapse

    Compute synaptic current and update synapse state for a generic ionotropic synapse.

    The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.

    The synaptic parameters are
    • gS: the maximal conductance across the postsynaptic membrane (uS)
    • e_syn: the reversal potential across the postsynaptic membrane (mV)
    • k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic receptor (s^-1)
    Details of this implementation can be found in the following book chapter

    L. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.

    Source code in jaxley/synapses/ionotropic.py
    class IonotropicSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n    The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n    open, and this depends on the amount of neurotransmitter released, which is in turn\n    dependent on the presynaptic voltage.\n\n    The synaptic parameters are:\n        - gS: the maximal conductance across the postsynaptic membrane (uS)\n        - e_syn: the reversal potential across the postsynaptic membrane (mV)\n        - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n            receptor (s^-1)\n\n    Details of this implementation can be found in the following book chapter:\n        L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n        Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_e_syn\": 0.0,\n            f\"{prefix}_k_minus\": 0.025,\n        }\n        self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        v_th = -35.0  # mV\n        delta = 10.0  # mV\n\n        s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n        tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n        slope = -1.0 / tau_s\n        exp_term = save_exp(slope * delta_t)\n        new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {f\"{prefix}_s\": new_s}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        prefix = self._name\n        g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n        return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/ionotropic.py
    def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    v_th = -35.0  # mV\n    delta = 10.0  # mV\n\n    s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n    tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * delta_t)\n    new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n    return {f\"{prefix}_s\": new_s}\n
    "},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":"

    Bases: Synapse

    Compute synaptic current for tanh synapse (no state).

    Source code in jaxley/synapses/tanh_rate.py
    class TanhRateSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current for tanh synapse (no state).\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_x_offset\": -70.0,\n            f\"{prefix}_slope\": 1.0,\n        }\n        self.synapse_states = {}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        current = (\n            -1\n            * params[f\"{prefix}_gS\"]\n            * jnp.tanh(\n                (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n            )\n        )\n        return current\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/tanh_rate.py
    def compute_current(\n    self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    current = (\n        -1\n        * params[f\"{prefix}_gS\"]\n        * jnp.tanh(\n            (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n        )\n    )\n    return current\n
    "},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

    Return updated synapse state and current.

    Source code in jaxley/synapses/tanh_rate.py
    def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    return {}\n
    "},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":"

    Bases: ABC

    Module base class.

    Modules are everything that can be passed to jx.integrate, i.e. compartments, branches, cells, and networks.

    This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).

    Source code in jaxley/modules/base.py
    class Module(ABC):\n    \"\"\"Module base class.\n\n    Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n    branches, cells, and networks.\n\n    This base class defines the scaffold for all jaxley modules (compartments,\n    branches, cells, networks).\n    \"\"\"\n\n    def __init__(self):\n        self.nseg: int = None\n        self.total_nbranches: int = 0\n        self.nbranches_per_cell: List[int] = None\n\n        self.group_nodes = {}\n\n        self.nodes: Optional[pd.DataFrame] = None\n\n        self.edges = pd.DataFrame(\n            columns=[\n                \"pre_locs\",\n                \"pre_branch_index\",\n                \"pre_cell_index\",\n                \"post_locs\",\n                \"post_branch_index\",\n                \"post_cell_index\",\n                \"type\",\n                \"type_ind\",\n                \"global_pre_comp_index\",\n                \"global_post_comp_index\",\n                \"global_pre_branch_index\",\n                \"global_post_branch_index\",\n            ]\n        )\n\n        self.cumsum_nbranches: Optional[jnp.ndarray] = None\n\n        self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n        self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])]\n\n        self.initialized_morph: bool = False\n        self.initialized_syns: bool = False\n\n        # List of all types of `jx.Synapse`s.\n        self.synapses: List = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n        self.synapse_names = []\n\n        # List of types of all `jx.Channel`s.\n        self.channels: List[Channel] = []\n        self.membrane_current_names: List[str] = []\n\n        # For trainable parameters.\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.allow_make_trainable: bool = True\n        self.num_trainable_params: int = 0\n\n        # For recordings.\n        self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n        # For stimuli or clamps.\n        # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n        # for 1000 timesteps and two compartments.\n        self.externals: Dict[str, jnp.ndarray] = {}\n        # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n        self.external_inds: Dict[str, jnp.ndarray] = {}\n\n        # x, y, z coordinates and radius.\n        self.xyzr: List[np.ndarray] = []\n\n        # For debugging the solver. Will be empty by default and only filled if\n        # `self._init_morph_for_debugging` is run.\n        self.debug_states = {}\n\n    def _update_nodes_with_xyz(self):\n        \"\"\"Add xyz coordinates to nodes.\"\"\"\n        num_branches = len(self.xyzr)\n        x = np.linspace(\n            0.5 / self.nseg,\n            (num_branches * 1 - 0.5 / self.nseg),\n            num_branches * self.nseg,\n        )\n        x += np.arange(num_branches).repeat(\n            self.nseg\n        )  # add offset to prevent branch loc overlap\n        xp = np.hstack(\n            [np.linspace(0, 1, x.shape[0]) + 2 * i for i, x in enumerate(self.xyzr)]\n        )\n        xyz = v_interp(x, xp, np.vstack(self.xyzr)[:, :3])\n        idcs = self.nodes[\"comp_index\"]\n        self.nodes.loc[idcs, [\"x\", \"y\", \"z\"]] = xyz.T\n        return xyz.T\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.show()` for details.\"\n\n    def __str__(self):\n        return f\"jx.{type(self).__name__}\"\n\n    def __dir__(self):\n        base_dir = object.__dir__(self)\n        return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n    def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n        \"\"\"Insert the default params of the module (e.g. radius, length).\n\n        This is run at `__init__()`. It does not deal with channels.\n        \"\"\"\n        for param_name, param_value in param_dict.items():\n            self.nodes[param_name] = param_value\n        for state_name, state_value in state_dict.items():\n            self.nodes[state_name] = state_value\n\n    def _gather_channels_from_constituents(self, constituents: List):\n        \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n        This is run at `__init__()`. It takes all branches of constituents (e.g.\n        of all branches when the are assembled into a cell) and adds columns to\n        `.nodes` for the relevant channels.\n        \"\"\"\n        for module in constituents:\n            for channel in module.channels:\n                if channel._name not in [c._name for c in self.channels]:\n                    self.channels.append(channel)\n                if channel.current_name not in self.membrane_current_names:\n                    self.membrane_current_names.append(channel.current_name)\n        # Setting columns of channel names to `False` instead of `NaN`.\n        for channel in self.channels:\n            name = channel._name\n            self.nodes.loc[self.nodes[name].isna(), name] = False\n\n    def to_jax(self):\n        \"\"\"Move `.nodes` to `.jaxnodes`.\n\n        Before the actual simulation is run (via `jx.integrate`), all parameters of\n        the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n        simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n        they can be processed on GPU/TPU and such that the simulation can be\n        differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n        \"\"\"\n        self.jaxnodes = {}\n        for key, value in self.nodes.to_dict(orient=\"list\").items():\n            inds = jnp.arange(len(value))\n            self.jaxnodes[key] = jnp.asarray(value)[inds]\n\n        # `jaxedges` contains only parameters (no indices).\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        self.jaxedges = {}\n        edges = self.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.synapses):\n            for key in synapse.synapse_params:\n                condition = np.asarray(edges[\"type_ind\"]) == i\n                self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n            for key in synapse.synapse_states:\n                self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n    def show(\n        self,\n        param_names: Optional[Union[str, List[str]]] = None,  # TODO.\n        *,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ) -> pd.DataFrame:\n        \"\"\"Print detailed information about the Module or a view of it.\n\n        Args:\n            param_names: The names of the parameters to show. If `None`, all parameters\n                are shown. NOT YET IMPLEMENTED.\n            indices: Whether to show the indices of the compartments.\n            params: Whether to show the parameters of the compartments.\n            states: Whether to show the states of the compartments.\n            channel_names: The names of the channels to show. If `None`, all channels are\n                shown.\n\n        Returns:\n            A `pd.DataFrame` with the requested information.\n        \"\"\"\n        return self._show(\n            self.nodes, param_names, indices, params, states, channel_names\n        )\n\n    def _show(\n        self,\n        view: pd.DataFrame,\n        param_names: Optional[Union[str, List[str]]] = None,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ):\n        \"\"\"Print detailed information about the entire Module.\"\"\"\n        printable_nodes = deepcopy(view)\n\n        for channel in self.channels:\n            name = channel._name\n            param_names = list(channel.channel_params.keys())\n            state_names = list(channel.channel_states.keys())\n            if channel_names is not None and name not in channel_names:\n                printable_nodes = printable_nodes.drop(name, axis=1)\n                printable_nodes = printable_nodes.drop(param_names, axis=1)\n                printable_nodes = printable_nodes.drop(state_names, axis=1)\n            else:\n                if not params:\n                    printable_nodes = printable_nodes.drop(param_names, axis=1)\n                if not states:\n                    printable_nodes = printable_nodes.drop(state_names, axis=1)\n\n        if not indices:\n            for name in [\"comp_index\", \"branch_index\", \"cell_index\"]:\n                printable_nodes = printable_nodes.drop(name, axis=1)\n\n        return printable_nodes\n\n    @abstractmethod\n    def init_conds(self, params: Dict):\n        \"\"\"Initialize coupling conductances.\n\n        Args:\n            params: Conductances and morphology parameters, not yet including\n                coupling conductances.\n        \"\"\"\n        raise NotImplementedError\n\n    def _append_channel_to_nodes(self, view: pd.DataFrame, channel: \"jx.Channel\"):\n        \"\"\"Adds channel nodes from constituents to `self.channel_nodes`.\"\"\"\n        name = channel._name\n\n        # Channel does not yet exist in the `jx.Module` at all.\n        if name not in [c._name for c in self.channels]:\n            self.channels.append(channel)\n            self.nodes[name] = False  # Previous columns do not have the new channel.\n\n        if channel.current_name not in self.membrane_current_names:\n            self.membrane_current_names.append(channel.current_name)\n\n        # Add a binary column that indicates if a channel is present.\n        self.nodes.loc[view.index.values, name] = True\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_params:\n            self.nodes.loc[view.index.values, key] = channel.channel_params[key]\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_states:\n            self.nodes.loc[view.index.values, key] = channel.channel_states[key]\n\n    def set(self, key: str, val: Union[float, jnp.ndarray]):\n        \"\"\"Set parameter of module (or its view) to a new value.\n\n        Note that this function can not be called within `jax.jit` or `jax.grad`.\n        Instead, it should be used set the parameters of the module **before** the\n        simulation. Use `.data_set()` to set parameters during `jax.jit` or\n        `jax.grad`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n        \"\"\"\n        # TODO(@michaeldeistler) should we allow `.set()` for synaptic parameters\n        # without using the `SynapseView`, purely for consistency with `make_trainable`?\n        view = (\n            self.edges\n            if key in self.synapse_param_names or key in self.synapse_state_names\n            else self.nodes\n        )\n        self._set(key, val, view, view)\n\n    def _set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        view: pd.DataFrame,\n        table_to_update: pd.DataFrame,\n    ):\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            table_to_update.loc[view.index.values, key] = val\n        else:\n            raise KeyError(\"Key not recognized.\")\n\n    def data_set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        param_state: Optional[List[Dict]],\n    ):\n        \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n            param_state: State of the setted parameters, internally used such that this\n                function does not modify global state.\n        \"\"\"\n        view = (\n            self.edges\n            if key in self.synapse_param_names or key in self.synapse_state_names\n            else self.nodes\n        )\n        return self._data_set(key, val, view, param_state=param_state)\n\n    def _data_set(\n        self,\n        key: str,\n        val: Tuple[float, jnp.ndarray],\n        view: pd.DataFrame,\n        param_state: Optional[List[Dict]] = None,\n    ):\n        # Note: `data_set` does not support arrays for `val`.\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            added_param_state = [\n                {\n                    \"indices\": np.atleast_2d(view.index.values),\n                    \"key\": key,\n                    \"val\": jnp.atleast_1d(jnp.asarray(val)),\n                }\n            ]\n            if param_state is not None:\n                param_state += added_param_state\n            else:\n                param_state = added_param_state\n        else:\n            raise KeyError(\"Key not recognized.\")\n        return param_state\n\n    def make_trainable(\n        self,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        \"\"\"Make a parameter trainable.\n\n        If a parameter is made trainable, it will be returned by `get_parameters()`\n        and should then be passed to `jx.integrate(..., params=params)`.\n\n        Args:\n            key: Name of the parameter to make trainable.\n            init_val: Initial value of the parameter. If `float`, the same value is\n                used for every created parameter. If `list`, the length of the list has\n                to match the number of created parameters. If `None`, the current\n                parameter value is used and if parameter sharing is performed that the\n                current parameter value is averaged over all shared parameters.\n            verbose: Whether to print the number of parameters that are added and the\n                total number of parameters.\n        \"\"\"\n        assert (\n            key not in self.synapse_param_names and key not in self.synapse_state_names\n        ), \"Parameters of synapses can only be made trainable via the `SynapseView`.\"\n        view = self.nodes\n        view = deepcopy(view.assign(controlled_by_param=0))\n        self._make_trainable(view, key, init_val, verbose=verbose)\n\n    def _make_trainable(\n        self,\n        view: pd.DataFrame,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        assert (\n            self.allow_make_trainable\n        ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n\n        if key in view.columns:\n            view = view[~np.isnan(view[key])]\n            grouped_view = view.groupby(\"controlled_by_param\")\n            # Because of this `x.index.values` we cannot support `make_trainable()` on\n            # the module level for synapse parameters (but only for `SynapseView`).\n            inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))\n\n            # Sorted inds are only used to infer the correct starting values.\n            param_vals = jnp.asarray(\n                [view.loc[inds, key].to_numpy() for inds in inds_of_comps]\n            )\n        else:\n            raise KeyError(f\"Parameter {key} not recognized.\")\n\n        indices_per_param = jnp.stack(inds_of_comps)\n        self.indices_set_by_trainables.append(indices_per_param)\n\n        # Set the value which the trainable parameter should take.\n        num_created_parameters = len(indices_per_param)\n        if init_val is not None:\n            if isinstance(init_val, float):\n                new_params = jnp.asarray([init_val] * num_created_parameters)\n            elif isinstance(init_val, list):\n                assert (\n                    len(init_val) == num_created_parameters\n                ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n                new_params = jnp.asarray(init_val)\n            else:\n                raise ValueError(\n                    f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n                )\n        else:\n            new_params = jnp.mean(param_vals, axis=1)\n\n        self.trainable_params.append({key: new_params})\n        self.num_trainable_params += num_created_parameters\n        if verbose:\n            print(\n                f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}\"\n            )\n\n    def delete_trainables(self):\n        \"\"\"Removes all trainable parameters from the module.\"\"\"\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.num_trainable_params: int = 0\n\n    def add_to_group(self, group_name: str):\n        \"\"\"Add a view of the module to a group.\n\n        Groups can then be indexed. For example:\n        ```python\n        net.cell(0).add_to_group(\"excitatory\")\n        net.excitatory.set(\"radius\", 0.1)\n        ```\n\n        Args:\n            group_name: The name of the group.\n        \"\"\"\n        raise ValueError(\"`add_to_group()` makes no sense for an entire module.\")\n\n    def _add_to_group(self, group_name: str, view: pd.DataFrame):\n        if group_name in self.group_nodes:\n            view = pd.concat([self.group_nodes[group_name], view])\n        self.group_nodes[group_name] = view\n\n    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n        \"\"\"Get all trainable parameters.\n\n        The returned parameters should be passed to `jx.integrate(..., params=params).\n\n        Returns:\n            A list of all trainable parameters in the form of\n                [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n        \"\"\"\n        return self.trainable_params\n\n    def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]:\n        \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n        Runs `init_conds()` and return every parameter that is needed to solve the ODE.\n        This includes conductances, radiuses, lengths, axial_resistivities, but also\n        coupling conductances.\n\n        This is done by first obtaining the current value of every parameter (not only\n        the trainable ones) and then replacing the trainable ones with the value\n        in `trainable_params()`. This function is run within `jx.integrate()`.\n\n        pstate can be obtained by calling `params_to_pstate()`.\n        ```\n        params = module.get_parameters() # i.e. [0, 1, 2]\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        module.to_jax() # needed for call to module.jaxnodes\n        ```\n\n        Args:\n            pstate: The state of the trainable parameters. pstate takes the form\n                [{\"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]), \"val\": jnp.array([0.1, 0.2, 0.3])}, ...].\n\n        Returns:\n            A dictionary of all module parameters.\n        \"\"\"\n        params = {}\n        for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n            params[key] = self.jaxnodes[key]\n\n        for channel in self.channels:\n            for channel_params in channel.channel_params:\n                params[channel_params] = self.jaxnodes[channel_params]\n\n        for synapse_params in self.synapse_param_names:\n            params[synapse_params] = self.jaxedges[synapse_params]\n\n        # Override with those parameters set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in params:  # Only parameters, not initial states.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                params[key] = params[key].at[inds].set(set_param[:, None])\n\n        # Compute conductance params and append them.\n        cond_params = self.init_conds(params)\n        for key in cond_params:\n            params[key] = cond_params[key]\n\n        return params\n\n    def get_all_states(\n        self, pstate: List[Dict], all_params, delta_t: float\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n        Args:\n            pstate: The state of the trainable parameters.\n            all_params: All parameters of the module.\n            delta_t: The time step.\n\n        Returns:\n            A dictionary of all states of the module.\n        \"\"\"\n        # Join node and edge states into a single state dictionary.\n        states = {\"v\": self.jaxnodes[\"v\"]}\n        for channel in self.channels:\n            for channel_states in channel.channel_states:\n                states[channel_states] = self.jaxnodes[channel_states]\n        for synapse_states in self.synapse_state_names:\n            states[synapse_states] = self.jaxedges[synapse_states]\n\n        # Override with the initial states set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in states:  # Only initial states, not parameters.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                states[key] = states[key].at[inds].set(set_param[:, None])\n\n        # Add to the states the initial current through every channel.\n        states, _ = self._channel_currents(\n            states, delta_t, self.channels, self.nodes, all_params\n        )\n\n        # Add to the states the initial current through every synapse.\n        states, _ = self._synapse_currents(\n            states, self.synapses, all_params, delta_t, self.edges\n        )\n        return states\n\n    @property\n    def initialized(self):\n        \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n        return self.initialized_morph and self.initialized_syns\n\n    def initialize(self):\n        \"\"\"Initialize the module.\"\"\"\n        self.init_morph()\n        return self\n\n    def init_states(self):\n        \"\"\"Initialize all mechanisms in their steady state.\n\n        This considers the voltages and parameters of each compartment.\"\"\"\n        # Update states of the channels.\n        channel_nodes = self.nodes\n\n        for channel in self.channels:\n            name = channel._name\n            indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n            voltages = channel_nodes.loc[indices, \"v\"].to_numpy()\n\n            channel_param_names = list(channel.channel_params.keys())\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = channel_nodes[p][indices].to_numpy()\n\n            init_state = channel.init_state(voltages, channel_params)\n\n            # `init_state` might not return all channel states. Only the ones that are\n            # returned are updated here.\n            for key, val in init_state.items():\n                self.nodes.loc[indices, key] = val\n\n    def _init_morph_for_debugging(self):\n        \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n        This is important only for expert users who try to modify the solver for the\n        voltage equations. By default, this function is never run.\n\n        This is useful for debugging the solver because one can use\n        `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n        Here is the code snippet that can be used for debugging then (to be inserted in\n        `solver_voltage`):\n        ```python\n        from scipy.sparse import csc_matrix\n        from scipy.sparse.linalg import spsolve\n        from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n        elements, solve, num_entries, start_ind_for_branchpoints = (\n            build_voltage_matrix_elements(\n                uppers,\n                lowers,\n                diags,\n                solves,\n                branchpoint_conds_children[debug_states[\"child_inds\"]],\n                branchpoint_conds_parents[debug_states[\"par_inds\"]],\n                branchpoint_weights_children[debug_states[\"child_inds\"]],\n                branchpoint_weights_parents[debug_states[\"par_inds\"]],\n                branchpoint_diags,\n                branchpoint_solves,\n                debug_states[\"nseg\"],\n                nbranches,\n            )\n        )\n        sparse_matrix = csc_matrix(\n            (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n            shape=(num_entries, num_entries),\n        )\n        solution = spsolve(sparse_matrix, solve)\n        solution = solution[:start_ind_for_branchpoints]  # Delete branchpoint voltages.\n        solves = jnp.reshape(solution, (debug_states[\"nseg\"], nbranches))\n        return solves\n        ```\n        \"\"\"\n        # For scipy and jax.scipy.\n        row_and_col_inds = compute_morphology_indices(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.par_inds,\n            self.child_inds,\n            self.nseg,\n            self.total_nbranches,\n        )\n\n        num_elements = len(row_and_col_inds[\"row_inds\"])\n        data_inds, indices, indptr = convert_to_csc(\n            num_elements=num_elements,\n            row_ind=row_and_col_inds[\"row_inds\"],\n            col_ind=row_and_col_inds[\"col_inds\"],\n        )\n        self.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n        self.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n        self.debug_states[\"data_inds\"] = data_inds\n        self.debug_states[\"indices\"] = indices\n        self.debug_states[\"indptr\"] = indptr\n\n        self.debug_states[\"nseg\"] = self.nseg\n        self.debug_states[\"child_inds\"] = self.child_inds\n        self.debug_states[\"par_inds\"] = self.par_inds\n\n    def record(self, state: str = \"v\", verbose: bool = True):\n        \"\"\"Insert a recording into the compartment.\n\n        Args:\n            state: The name of the state to record.\n            verbose: Whether to print number of inserted recordings.\"\"\"\n        view = deepcopy(self.nodes)\n        view[\"state\"] = state\n        recording_view = view[[\"comp_index\", \"state\"]]\n        recording_view = recording_view.rename(columns={\"comp_index\": \"rec_index\"})\n        self._record(recording_view, verbose=verbose)\n\n    def _record(self, view: pd.DataFrame, verbose: bool = True):\n        self.recordings = pd.concat([self.recordings, view], ignore_index=True)\n        if verbose:\n            print(f\"Added {len(view)} recordings. See `.recordings` for details.\")\n\n    def delete_recordings(self):\n        \"\"\"Removes all recordings from the module.\"\"\"\n        self.recordings = pd.DataFrame().from_dict({})\n\n    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n        \"\"\"Insert a stimulus into the compartment.\n\n        current must be a 1d array or have batch dimension of size `(num_compartments, )`\n        or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n        This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n        it should only be used for static stimuli (i.e., stimuli that do not depend\n        on the data and that should not be learned). For stimuli that depend on data\n        (or that should be learned), please use `data_stimulate()`.\n\n        Args:\n            current: Current in `nA`.\n        \"\"\"\n        self._external_input(\"i\", current, self.nodes, verbose=verbose)\n\n    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n        \"\"\"Clamp a state to a given value across specified compartments.\n\n        Args:\n            state_name: The name of the state to clamp.\n            state_array (jnp.nd: Array of values to clamp the state to.\n            verbose : If True, prints details about the clamping.\n\n        This function sets external states for the compartments.\n        \"\"\"\n        if state_name not in self.nodes.columns:\n            raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n        self._external_input(state_name, state_array, self.nodes, verbose=verbose)\n\n    def _external_input(\n        self,\n        key: str,\n        values: Optional[jnp.ndarray],\n        view: pd.DataFrame,\n        verbose: bool = True,\n    ):\n        values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n        batch_size = values.shape[0]\n        is_multiple = len(view) == batch_size\n        values = values if is_multiple else jnp.repeat(values, len(view), axis=0)\n        assert batch_size in [1, len(view)], \"Number of comps and stimuli do not match.\"\n\n        if key in self.externals.keys():\n            self.externals[key] = jnp.concatenate([self.externals[key], values])\n            self.external_inds[key] = jnp.concatenate(\n                [self.external_inds[key], view.comp_index.to_numpy()]\n            )\n        else:\n            self.externals[key] = values\n            self.external_inds[key] = view.comp_index.to_numpy()\n\n        if verbose:\n            print(f\"Added {len(view)} external_states. See `.externals` for details.\")\n\n    def data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        \"\"\"Insert a stimulus into the module within jit (or grad).\n\n        Args:\n            current: Current in `nA`.\n            verbose: Whether or not to print the number of inserted stimuli. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        return self._data_stimulate(current, data_stimuli, self.nodes, verbose=verbose)\n\n    def _data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n        view: pd.DataFrame,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        current = current if current.ndim == 2 else jnp.expand_dims(current, axis=0)\n        batch_size = current.shape[0]\n        is_multiple = len(view) == batch_size\n        current = current if is_multiple else jnp.repeat(current, len(view), axis=0)\n        assert batch_size in [1, len(view)], \"Number of comps and stimuli do not match.\"\n\n        if data_stimuli is not None:\n            currents = data_stimuli[0]\n            inds = data_stimuli[1]\n        else:\n            currents = None\n            inds = pd.DataFrame().from_dict({})\n\n        # Same as in `.stimulate()`.\n        if currents is not None:\n            currents = jnp.concatenate([currents, current])\n        else:\n            currents = current\n        inds = pd.concat([inds, view])\n\n        if verbose:\n            print(f\"Added {len(view)} stimuli.\")\n\n        return (currents, inds)\n\n    def delete_stimuli(self):\n        \"\"\"Removes all stimuli from the module.\"\"\"\n        self.externals.pop(\"i\", None)\n        self.external_inds.pop(\"i\", None)\n\n    def insert(self, channel: Channel):\n        \"\"\"Insert a channel into the module.\n\n        Args:\n            channel: The channel to insert.\"\"\"\n        self._insert(channel, self.nodes)\n\n    def _insert(self, channel, view):\n        self._append_channel_to_nodes(view, channel)\n\n    def init_syns(self):\n        self.initialized_syns = True\n\n    def init_morph(self):\n        self.initialized_morph = True\n\n    def step(\n        self,\n        u: Dict[str, jnp.ndarray],\n        delta_t: float,\n        external_inds: Dict[str, jnp.ndarray],\n        externals: Dict[str, jnp.ndarray],\n        params: Dict[str, jnp.ndarray],\n        solver: str = \"bwd_euler\",\n        voltage_solver: str = \"jaxley.stone\",\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One step of solving the Ordinary Differential Equation.\n\n        This function is called inside of `integrate` and increments the state of the\n        module by one time step. Calls `_step_channels` and `_step_synapse` to update\n        the states of the channels and synapses using fwd_euler.\n\n        Args:\n            u: The state of the module. voltages = u[\"v\"]\n            delta_t: The time step.\n            external_inds: The indices of the external inputs.\n            externals: The external inputs.\n            params: The parameters of the module.\n            solver: The solver to use for the voltages. Either \"bwd_euler\" or \"fwd_euler\".\n            voltage_solver: The tridiagonal solver to used to diagonalize the\n                coefficient matrix of the ODE system. Either \"jaxley.thomas\",\n                \"jaxley.stone\", or \"jax.scipy.sparse\".\n\n        Returns:\n            The updated state of the module.\n        \"\"\"\n\n        # Extract the voltages\n        voltages = u[\"v\"]\n\n        # Extract the external inputs\n        has_current = \"i\" in externals.keys()\n        i_current = externals[\"i\"] if has_current else jnp.asarray([]).astype(\"float\")\n        i_inds = external_inds[\"i\"] if has_current else jnp.asarray([]).astype(\"int32\")\n        i_ext = self._get_external_input(\n            voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n        )\n\n        # Step of the channels.\n        u, (v_terms, const_terms) = self._step_channels(\n            u, delta_t, self.channels, self.nodes, params\n        )\n\n        # Step of the synapse.\n        u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n            u,\n            self.synapses,\n            params,\n            delta_t,\n            self.edges,\n        )\n\n        # Clamp for channels and synapses.\n        for key in externals.keys():\n            if key not in [\"i\", \"v\"]:\n                u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n        # Voltage steps.\n        cm = params[\"capacitance\"]  # Abbreviation.\n        if solver == \"bwd_euler\":\n            new_voltages = step_voltage_implicit(\n                voltages=voltages,\n                voltage_terms=(v_terms + syn_v_terms) / cm,\n                constant_terms=(const_terms + i_ext + syn_const_terms) / cm,\n                coupling_conds_upper=params[\"branch_uppers\"],\n                coupling_conds_lower=params[\"branch_lowers\"],\n                summed_coupling_conds=params[\"branch_diags\"],\n                branchpoint_conds_children=params[\"branchpoint_conds_children\"],\n                branchpoint_conds_parents=params[\"branchpoint_conds_parents\"],\n                branchpoint_weights_children=params[\"branchpoint_weights_children\"],\n                branchpoint_weights_parents=params[\"branchpoint_weights_parents\"],\n                par_inds=self.par_inds,\n                child_inds=self.child_inds,\n                nbranches=self.total_nbranches,\n                solver=voltage_solver,\n                delta_t=delta_t,\n                children_in_level=self.children_in_level,\n                parents_in_level=self.parents_in_level,\n                root_inds=self.root_inds,\n                branchpoint_group_inds=self.branchpoint_group_inds,\n                debug_states=self.debug_states,\n            )\n        else:\n            new_voltages = step_voltage_explicit(\n                voltages,\n                (v_terms + syn_v_terms) / cm,\n                (const_terms + i_ext + syn_const_terms) / cm,\n                coupling_conds_bwd=params[\"coupling_conds_bwd\"],\n                coupling_conds_fwd=params[\"coupling_conds_fwd\"],\n                branch_cond_fwd=params[\"branch_conds_fwd\"],\n                branch_cond_bwd=params[\"branch_conds_bwd\"],\n                nbranches=self.total_nbranches,\n                parents=self.comb_parents,\n                delta_t=delta_t,\n            )\n\n        u[\"v\"] = new_voltages.ravel(order=\"C\")\n\n        # Clamp for voltages.\n        if \"v\" in externals.keys():\n            u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n        return u\n\n    def _step_channels(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n        states = self._step_channels_state(\n            states, delta_t, channels, channel_nodes, params\n        )\n        states, current_terms = self._channel_currents(\n            states, delta_t, channels, channel_nodes, params\n        )\n        return states, current_terms\n\n    def _step_channels_state(\n        self,\n        states,\n        delta_t,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One integration step of the channels.\"\"\"\n        voltages = states[\"v\"]\n\n        query = lambda d, keys, idcs: dict(\n            zip(keys, (v[idcs] for v in map(d.get, keys)))\n        )  # get dict with subset of keys and values from d\n        # only loops over necessary keys, as opposed to looping over d.items()\n\n        # Update states of the channels.\n        indices = channel_nodes[\"comp_index\"].to_numpy()\n        for channel in channels:\n            channel_param_names = list(channel.channel_params)\n            channel_param_names += [\"radius\", \"length\", \"axial_resistivity\"]\n            channel_state_names = list(channel.channel_states)\n            channel_state_names += self.membrane_current_names\n            channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n            channel_params = query(params, channel_param_names, channel_indices)\n            channel_states = query(states, channel_state_names, channel_indices)\n\n            states_updated = channel.update_states(\n                channel_states, delta_t, voltages[channel_indices], channel_params\n            )\n            # Rebuild state. This has to be done within the loop over channels to allow\n            # multiple channels which modify the same state.\n            for key, val in states_updated.items():\n                states[key] = states[key].at[channel_indices].set(val)\n\n        return states\n\n    def _channel_currents(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the current through each channel.\n\n        This is also updates `state` because the `state` also contains the current.\n        \"\"\"\n        voltages = states[\"v\"]\n\n        # Compute current through channels.\n        voltage_terms = jnp.zeros_like(voltages)\n        constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n\n        current_states = {}\n        for name in self.membrane_current_names:\n            current_states[name] = jnp.zeros_like(voltages)\n\n        for channel in channels:\n            name = channel._name\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = params[p][indices]\n            channel_params[\"radius\"] = params[\"radius\"][indices]\n            channel_params[\"length\"] = params[\"length\"][indices]\n            channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n            channel_states = {}\n            for s in channel_state_names:\n                channel_states[s] = states[s][indices]\n\n            v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n            membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n                channel_states, v_and_perturbed, channel_params\n            )\n            voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n            constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n            voltage_terms = voltage_terms.at[indices].add(voltage_term)\n            constant_terms = constant_terms.at[indices].add(-constant_term)\n\n            # Save the current (for the unperturbed voltage) as a state that will\n            # also be passed to the state update.\n            current_states[channel.current_name] = (\n                current_states[channel.current_name]\n                .at[indices]\n                .add(membrane_currents[0])\n            )\n\n        # Copy the currents into the `state` dictionary such that they can be\n        # recorded and used by `Channel.update_states()`.\n        for name in self.membrane_current_names:\n            states[name] = current_states[name]\n\n        return states, (voltage_terms, constant_terms)\n\n    def _step_synapse(\n        self,\n        u: Dict[str, jnp.ndarray],\n        syn_channels: List[Channel],\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels.\n\n        `Network` overrides this method (because it actually has synapses), whereas\n        `Compartment`, `Branch`, and `Cell` do not override this.\n        \"\"\"\n        voltages = u[\"v\"]\n        return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n    def _synapse_currents(\n        self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        return states, (None, None)\n\n    @staticmethod\n    def _get_external_input(\n        voltages: jnp.ndarray,\n        i_inds: jnp.ndarray,\n        i_stim: jnp.ndarray,\n        radius: float,\n        length_single_compartment: float,\n    ) -> jnp.ndarray:\n        \"\"\"\n        Return external input to each compartment in uA / cm^2.\n\n        Args:\n            voltages: mV.\n            i_stim: nA.\n            radius: um.\n            length_single_compartment: um.\n        \"\"\"\n        zero_vec = jnp.zeros_like(voltages)\n        current = convert_point_process_to_distributed(\n            i_stim, radius[i_inds], length_single_compartment[i_inds]\n        )\n\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0,),\n            scatter_dims_to_operand_dims=(0,),\n        )\n        stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n        return stim_at_timestep\n\n    def vis(\n        self,\n        ax: Optional[Axes] = None,\n        col: str = \"k\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        morph_plot_kwargs: Dict = {},\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            ax: An axis into which to plot.\n            col: The color for all branches.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            morph_plot_kwargs: Keyword arguments passed to the plotting function.\n        \"\"\"\n        return self._vis(\n            dims=dims,\n            col=col,\n            ax=ax,\n            view=self.nodes,\n            type=type,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n    def _vis(\n        self,\n        ax: Axes,\n        col: str,\n        dims: Tuple[int],\n        view: pd.DataFrame,\n        type: str,\n        morph_plot_kwargs: Dict,\n    ) -> Axes:\n        branches_inds = view[\"branch_index\"].to_numpy()\n        coords = []\n        for branch_ind in branches_inds:\n            assert not np.any(\n                np.isnan(self.xyzr[branch_ind][:, dims])\n            ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n            coords.append(self.xyzr[branch_ind])\n\n        ax = plot_morph(\n            coords,\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=type,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        return ax\n\n    def _scatter(self, ax, col, dims, view, morph_plot_kwargs):\n        \"\"\"Scatter visualization (used only for compartments).\"\"\"\n        assert len(view) == 1, \"Scatter only deals with compartments.\"\n        branch_ind = view[\"branch_index\"].to_numpy().item()\n        comp_ind = view[\"comp_index\"].to_numpy().item()\n        assert not np.any(\n            np.isnan(self.xyzr[branch_ind][:, dims])\n        ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n        comp_fraction = loc_of_index(comp_ind, self.nseg)\n        coords = self.xyzr[branch_ind]\n        interpolated_xyz = interpolate_xyz(comp_fraction, coords)\n\n        ax = plot_morph(\n            np.asarray([[interpolated_xyz]]),\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=\"scatter\",\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        return ax\n\n    def compute_xyz(self):\n        \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n        This function should not be called if the morphology was read from an `.swc`\n        file. However, for morphologies that were constructed from scratch, this\n        function **must** be called before `.vis()`. The computed `xyz` coordinates\n        are only used for plotting.\n        \"\"\"\n        max_y_multiplier = 5.0\n        min_y_multiplier = 0.5\n\n        parents = self.comb_parents\n        num_children = _compute_num_children(parents)\n        index_of_child = _compute_index_of_child(parents)\n        levels = compute_levels(parents)\n\n        # Extract branch.\n        inds_branch = self.nodes.groupby(\"branch_index\")[\"comp_index\"].apply(list)\n        branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n        endpoints = []\n\n        # Different levels will get a different \"angle\" at which the children emerge from\n        # the parents. This angle is defined by the `y_offset_multiplier`. This value\n        # defines the range between y-location of the first and of the last child of a\n        # parent.\n        y_offset_multiplier = np.linspace(\n            max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n        )\n\n        for b in range(self.total_nbranches):\n            # For networks with mixed SWC and from-scatch neurons, only update those\n            # branches that do not have coordingates yet.\n            if np.any(np.isnan(self.xyzr[b])):\n                if parents[b] > -1:\n                    start_point = endpoints[parents[b]]\n                    num_children_of_parent = num_children[parents[b]]\n                    if num_children_of_parent > 1:\n                        y_offset = (\n                            ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                        ) * y_offset_multiplier[levels[b]]\n                    else:\n                        y_offset = 0.0\n                else:\n                    start_point = [0, 0, 0]\n                    y_offset = 0.0\n\n                len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n                end_point = [\n                    start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                    start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                    start_point[2],\n                ]\n                endpoints.append(end_point)\n\n                self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n            else:\n                # Dummy to keey the index `endpoints[parent[b]]` above working.\n                endpoints.append(np.zeros((2,)))\n\n    def move(\n        self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True\n    ):\n        \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            x: The amount to move in the x direction in um.\n            y: The amount to move in the y direction in um.\n            z: The amount to move in the z direction in um.\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        self._move(x, y, z, self.nodes, update_nodes)\n\n    def _move(self, x: float, y: float, z: float, view, update_nodes: bool):\n        # Need to cast to set because this will return one columnn per compartment,\n        # not one column per branch.\n        indizes = set(view[\"branch_index\"].to_numpy().tolist())\n        for i in indizes:\n            self.xyzr[i][:, 0] += x\n            self.xyzr[i][:, 1] += y\n            self.xyzr[i][:, 2] += z\n        if update_nodes:\n            self._update_nodes_with_xyz()\n\n    def move_to(\n        self,\n        x: Union[float, np.ndarray] = 0.0,\n        y: Union[float, np.ndarray] = 0.0,\n        z: Union[float, np.ndarray] = 0.0,\n        update_nodes: bool = True,\n    ):\n        \"\"\"Move cells or networks to a location (x, y, z).\n\n        If x, y, and z are floats, then the first compartment of the first branch\n        of the first cell is moved to that float coordinate, and everything else is\n        shifted by the difference between that compartment's previous coordinate and\n        the new float location.\n\n        If x, y, and z are arrays, then they must each have a length equal to the number\n        of cells being moved. Then the first compartment of the first branch of each\n        cell is moved to the specified location.\n\n        Args:\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        self._move_to(x, y, z, self.nodes, update_nodes)\n\n    def _move_to(\n        self,\n        x: Union[float, np.ndarray],\n        y: Union[float, np.ndarray],\n        z: Union[float, np.ndarray],\n        view: pd.DataFrame,\n        update_nodes: bool,\n    ):\n        # Test if any coordinate values are NaN which would greatly affect moving\n        if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n            raise ValueError(\n                \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n            )\n\n        # Get the indices of the cells and branches to move\n        cell_inds = list(view.cell_index.unique())\n        branch_inds = view.branch_index.unique()\n\n        if (\n            isinstance(x, np.ndarray)\n            and isinstance(y, np.ndarray)\n            and isinstance(z, np.ndarray)\n        ):\n            assert (\n                x.shape == y.shape == z.shape == (len(cell_inds),)\n            ), \"x, y, and z array shapes are not all equal to the number of cells to be moved.\"\n\n            # Split the branches by cell id\n            tup_indices = np.array([view.cell_index, view.branch_index])\n            view_cell_branch_inds = np.unique(tup_indices, axis=1)[0]\n            _, branch_split_inds = np.unique(view_cell_branch_inds, return_index=True)\n            branches_by_cell = np.split(\n                view.branch_index.unique(), branch_split_inds[1:]\n            )\n\n            # Calculate the amount to shift all of the branches of each cell\n            shift_amounts = (\n                np.array([x, y, z]).T - np.stack(self[cell_inds, 0].xyzr)[:, 0, :3]\n            )\n\n        else:\n            # Treat as if all branches belong to the same cell to be moved\n            branches_by_cell = [branch_inds]\n            # Calculate the amount to shift all branches by the 1st branch of 1st cell\n            shift_amounts = [np.array([x, y, z]) - self[cell_inds].xyzr[0][0, :3]]\n\n        # Move all of the branches\n        for i, branches in enumerate(branches_by_cell):\n            for b in branches:\n                self.xyzr[b][:, :3] += shift_amounts[i]\n\n        if update_nodes:\n            self._update_nodes_with_xyz()\n\n    def rotate(self, degrees: float, rotation_axis: str = \"xy\"):\n        \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            degrees: How many degrees to rotate the module by.\n            rotation_axis: Either of {`xy` | `xz` | `yz`}.\n        \"\"\"\n        self._rotate(degrees=degrees, rotation_axis=rotation_axis, view=self.nodes)\n\n    def _rotate(self, degrees: float, rotation_axis: str, view: pd.DataFrame):\n        degrees = degrees / 180 * np.pi\n        if rotation_axis == \"xy\":\n            dims = [0, 1]\n        elif rotation_axis == \"xz\":\n            dims = [0, 2]\n        elif rotation_axis == \"yz\":\n            dims = [1, 2]\n        else:\n            raise ValueError\n\n        rotation_matrix = np.asarray(\n            [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n        )\n        indizes = set(view[\"branch_index\"].to_numpy().tolist())\n        for i in indizes:\n            rot = np.dot(rotation_matrix, self.xyzr[i][:, dims].T).T\n            self.xyzr[i][:, dims] = rot\n\n    @property\n    def shape(self) -> Tuple[int]:\n        \"\"\"Returns the number of submodules contained in a module.\n\n        ```\n        network.shape = (num_cells, num_branches, num_compartments)\n        cell.shape = (num_branches, num_compartments)\n        branch.shape = (num_compartments,)\n        ```\"\"\"\n        mod_name = self.__class__.__name__.lower()\n        if \"comp\" in mod_name:\n            return (1,)\n        elif \"branch\" in mod_name:\n            return self[:].shape[1:]\n        return self[:].shape\n\n    def __getitem__(self, index):\n        return self._getitem(self, index)\n\n    def _getitem(\n        self,\n        module: Union[\"Module\", \"View\"],\n        index: Union[Tuple, int],\n        child_name: Optional[str] = None,\n    ) -> \"View\":\n        \"\"\"Return View which is created from indexing the module.\n\n        Args:\n            module: The module to be indexed. Will be a `Module` if `._getitem` is\n                called from `__getitem__` in a `Module` and will be a `View` if it was\n                called from `__getitem__` in a `View`.\n            index: The index (or indices) to index the module.\n            child_name: If passed, this will be the key that is used to index the\n                `module`, e.g. if it is the string `branch` then we will try to call\n                `module.xyz(index)`. If `None` then we try to infer automatically what\n                the childview should be, given the name of the `module`.\n\n        Returns:\n            An indexed `View`.\n        \"\"\"\n        if isinstance(index, tuple):\n            if len(index) > 1:\n                return childview(module, index[0], child_name)[index[1:]]\n            return childview(module, index[0], child_name)\n        return childview(module, index, child_name)\n\n    def __iter__(self):\n        for i in range(self.shape[0]):\n            yield self[i]\n\n    def _local_inds_to_global(\n        self, cell_inds: np.ndarray, branch_inds: np.ndarray, comp_inds: np.ndarray\n    ):\n        \"\"\"Given local inds of cell, branch, and comp, return the global comp index.\"\"\"\n        global_ind = (\n            self.cumsum_nbranches[cell_inds] + branch_inds\n        ) * self.nseg + comp_inds\n        return global_ind.astype(int)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized property","text":"

    Whether the Module is ready to be solved or not.

    "},{"location":"reference/modules/#jaxley.modules.base.Module.shape","title":"shape: Tuple[int] property","text":"

    Returns the number of submodules contained in a module.

    network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)","text":"

    Add a view of the module to a group.

    Groups can then be indexed. For example:

    net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n

    Parameters:

    Name Type Description Default group_name str

    The name of the group.

    required Source code in jaxley/modules/base.py
    def add_to_group(self, group_name: str):\n    \"\"\"Add a view of the module to a group.\n\n    Groups can then be indexed. For example:\n    ```python\n    net.cell(0).add_to_group(\"excitatory\")\n    net.excitatory.set(\"radius\", 0.1)\n    ```\n\n    Args:\n        group_name: The name of the group.\n    \"\"\"\n    raise ValueError(\"`add_to_group()` makes no sense for an entire module.\")\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)","text":"

    Clamp a state to a given value across specified compartments.

    Parameters:

    Name Type Description Default state_name str

    The name of the state to clamp.

    required state_array nd

    Array of values to clamp the state to.

    required verbose

    If True, prints details about the clamping.

    True

    This function sets external states for the compartments.

    Source code in jaxley/modules/base.py
    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n    \"\"\"Clamp a state to a given value across specified compartments.\n\n    Args:\n        state_name: The name of the state to clamp.\n        state_array (jnp.nd: Array of values to clamp the state to.\n        verbose : If True, prints details about the clamping.\n\n    This function sets external states for the compartments.\n    \"\"\"\n    if state_name not in self.nodes.columns:\n        raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n    self._external_input(state_name, state_array, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()","text":"

    Return xyz coordinates of every branch, based on the branch length.

    This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

    Source code in jaxley/modules/base.py
    def compute_xyz(self):\n    \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n    This function should not be called if the morphology was read from an `.swc`\n    file. However, for morphologies that were constructed from scratch, this\n    function **must** be called before `.vis()`. The computed `xyz` coordinates\n    are only used for plotting.\n    \"\"\"\n    max_y_multiplier = 5.0\n    min_y_multiplier = 0.5\n\n    parents = self.comb_parents\n    num_children = _compute_num_children(parents)\n    index_of_child = _compute_index_of_child(parents)\n    levels = compute_levels(parents)\n\n    # Extract branch.\n    inds_branch = self.nodes.groupby(\"branch_index\")[\"comp_index\"].apply(list)\n    branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n    endpoints = []\n\n    # Different levels will get a different \"angle\" at which the children emerge from\n    # the parents. This angle is defined by the `y_offset_multiplier`. This value\n    # defines the range between y-location of the first and of the last child of a\n    # parent.\n    y_offset_multiplier = np.linspace(\n        max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n    )\n\n    for b in range(self.total_nbranches):\n        # For networks with mixed SWC and from-scatch neurons, only update those\n        # branches that do not have coordingates yet.\n        if np.any(np.isnan(self.xyzr[b])):\n            if parents[b] > -1:\n                start_point = endpoints[parents[b]]\n                num_children_of_parent = num_children[parents[b]]\n                if num_children_of_parent > 1:\n                    y_offset = (\n                        ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                    ) * y_offset_multiplier[levels[b]]\n                else:\n                    y_offset = 0.0\n            else:\n                start_point = [0, 0, 0]\n                y_offset = 0.0\n\n            len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n            end_point = [\n                start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                start_point[2],\n            ]\n            endpoints.append(end_point)\n\n            self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n        else:\n            # Dummy to keey the index `endpoints[parent[b]]` above working.\n            endpoints.append(np.zeros((2,)))\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)","text":"

    Set parameter of module (or its view) to a new value within jit.

    Parameters:

    Name Type Description Default key str

    The name of the parameter to set.

    required val Union[float, ndarray]

    The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

    required param_state Optional[List[Dict]]

    State of the setted parameters, internally used such that this function does not modify global state.

    required Source code in jaxley/modules/base.py
    def data_set(\n    self,\n    key: str,\n    val: Union[float, jnp.ndarray],\n    param_state: Optional[List[Dict]],\n):\n    \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n        param_state: State of the setted parameters, internally used such that this\n            function does not modify global state.\n    \"\"\"\n    view = (\n        self.edges\n        if key in self.synapse_param_names or key in self.synapse_state_names\n        else self.nodes\n    )\n    return self._data_set(key, val, view, param_state=param_state)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)","text":"

    Insert a stimulus into the module within jit (or grad).

    Parameters:

    Name Type Description Default current ndarray

    Current in nA.

    required verbose bool

    Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.

    False Source code in jaxley/modules/base.py
    def data_stimulate(\n    self,\n    current: jnp.ndarray,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n    \"\"\"Insert a stimulus into the module within jit (or grad).\n\n    Args:\n        current: Current in `nA`.\n        verbose: Whether or not to print the number of inserted stimuli. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    return self._data_stimulate(current, data_stimuli, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()","text":"

    Removes all recordings from the module.

    Source code in jaxley/modules/base.py
    def delete_recordings(self):\n    \"\"\"Removes all recordings from the module.\"\"\"\n    self.recordings = pd.DataFrame().from_dict({})\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()","text":"

    Removes all stimuli from the module.

    Source code in jaxley/modules/base.py
    def delete_stimuli(self):\n    \"\"\"Removes all stimuli from the module.\"\"\"\n    self.externals.pop(\"i\", None)\n    self.external_inds.pop(\"i\", None)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()","text":"

    Removes all trainable parameters from the module.

    Source code in jaxley/modules/base.py
    def delete_trainables(self):\n    \"\"\"Removes all trainable parameters from the module.\"\"\"\n    self.indices_set_by_trainables: List[jnp.ndarray] = []\n    self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n    self.num_trainable_params: int = 0\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate)","text":"

    Return all parameters (and coupling conductances) needed to simulate.

    Runs init_conds() and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.

    This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params(). This function is run within jx.integrate().

    pstate can be obtained by calling params_to_pstate().

    params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n

    Parameters:

    Name Type Description Default pstate List[Dict]

    The state of the trainable parameters. pstate takes the form [{\u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

    required

    Returns:

    Type Description Dict[str, ndarray]

    A dictionary of all module parameters.

    Source code in jaxley/modules/base.py
    def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]:\n    \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n    Runs `init_conds()` and return every parameter that is needed to solve the ODE.\n    This includes conductances, radiuses, lengths, axial_resistivities, but also\n    coupling conductances.\n\n    This is done by first obtaining the current value of every parameter (not only\n    the trainable ones) and then replacing the trainable ones with the value\n    in `trainable_params()`. This function is run within `jx.integrate()`.\n\n    pstate can be obtained by calling `params_to_pstate()`.\n    ```\n    params = module.get_parameters() # i.e. [0, 1, 2]\n    pstate = params_to_pstate(params, module.indices_set_by_trainables)\n    module.to_jax() # needed for call to module.jaxnodes\n    ```\n\n    Args:\n        pstate: The state of the trainable parameters. pstate takes the form\n            [{\"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]), \"val\": jnp.array([0.1, 0.2, 0.3])}, ...].\n\n    Returns:\n        A dictionary of all module parameters.\n    \"\"\"\n    params = {}\n    for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n        params[key] = self.jaxnodes[key]\n\n    for channel in self.channels:\n        for channel_params in channel.channel_params:\n            params[channel_params] = self.jaxnodes[channel_params]\n\n    for synapse_params in self.synapse_param_names:\n        params[synapse_params] = self.jaxedges[synapse_params]\n\n    # Override with those parameters set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in params:  # Only parameters, not initial states.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            params[key] = params[key].at[inds].set(set_param[:, None])\n\n    # Compute conductance params and append them.\n    cond_params = self.init_conds(params)\n    for key in cond_params:\n        params[key] = cond_params[key]\n\n    return params\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)","text":"

    Get the full initial state of the module from jaxnodes and trainables.

    Parameters:

    Name Type Description Default pstate List[Dict]

    The state of the trainable parameters.

    required all_params

    All parameters of the module.

    required delta_t float

    The time step.

    required

    Returns:

    Type Description Dict[str, ndarray]

    A dictionary of all states of the module.

    Source code in jaxley/modules/base.py
    def get_all_states(\n    self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n    Args:\n        pstate: The state of the trainable parameters.\n        all_params: All parameters of the module.\n        delta_t: The time step.\n\n    Returns:\n        A dictionary of all states of the module.\n    \"\"\"\n    # Join node and edge states into a single state dictionary.\n    states = {\"v\": self.jaxnodes[\"v\"]}\n    for channel in self.channels:\n        for channel_states in channel.channel_states:\n            states[channel_states] = self.jaxnodes[channel_states]\n    for synapse_states in self.synapse_state_names:\n        states[synapse_states] = self.jaxedges[synapse_states]\n\n    # Override with the initial states set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in states:  # Only initial states, not parameters.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            states[key] = states[key].at[inds].set(set_param[:, None])\n\n    # Add to the states the initial current through every channel.\n    states, _ = self._channel_currents(\n        states, delta_t, self.channels, self.nodes, all_params\n    )\n\n    # Add to the states the initial current through every synapse.\n    states, _ = self._synapse_currents(\n        states, self.synapses, all_params, delta_t, self.edges\n    )\n    return states\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()","text":"

    Get all trainable parameters.

    The returned parameters should be passed to `jx.integrate(\u2026, params=params).

    Returns:

    Type Description List[Dict[str, ndarray]]

    A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

    Source code in jaxley/modules/base.py
    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n    \"\"\"Get all trainable parameters.\n\n    The returned parameters should be passed to `jx.integrate(..., params=params).\n\n    Returns:\n        A list of all trainable parameters in the form of\n            [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n    \"\"\"\n    return self.trainable_params\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.init_conds","title":"init_conds(params) abstractmethod","text":"

    Initialize coupling conductances.

    Parameters:

    Name Type Description Default params Dict

    Conductances and morphology parameters, not yet including coupling conductances.

    required Source code in jaxley/modules/base.py
    @abstractmethod\ndef init_conds(self, params: Dict):\n    \"\"\"Initialize coupling conductances.\n\n    Args:\n        params: Conductances and morphology parameters, not yet including\n            coupling conductances.\n    \"\"\"\n    raise NotImplementedError\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states()","text":"

    Initialize all mechanisms in their steady state.

    This considers the voltages and parameters of each compartment.

    Source code in jaxley/modules/base.py
    def init_states(self):\n    \"\"\"Initialize all mechanisms in their steady state.\n\n    This considers the voltages and parameters of each compartment.\"\"\"\n    # Update states of the channels.\n    channel_nodes = self.nodes\n\n    for channel in self.channels:\n        name = channel._name\n        indices = channel_nodes.loc[channel_nodes[name]][\"comp_index\"].to_numpy()\n        voltages = channel_nodes.loc[indices, \"v\"].to_numpy()\n\n        channel_param_names = list(channel.channel_params.keys())\n        channel_params = {}\n        for p in channel_param_names:\n            channel_params[p] = channel_nodes[p][indices].to_numpy()\n\n        init_state = channel.init_state(voltages, channel_params)\n\n        # `init_state` might not return all channel states. Only the ones that are\n        # returned are updated here.\n        for key, val in init_state.items():\n            self.nodes.loc[indices, key] = val\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.initialize","title":"initialize()","text":"

    Initialize the module.

    Source code in jaxley/modules/base.py
    def initialize(self):\n    \"\"\"Initialize the module.\"\"\"\n    self.init_morph()\n    return self\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)","text":"

    Insert a channel into the module.

    Parameters:

    Name Type Description Default channel Channel

    The channel to insert.

    required Source code in jaxley/modules/base.py
    def insert(self, channel: Channel):\n    \"\"\"Insert a channel into the module.\n\n    Args:\n        channel: The channel to insert.\"\"\"\n    self._insert(channel, self.nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)","text":"

    Make a parameter trainable.

    If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(..., params=params).

    Parameters:

    Name Type Description Default key str

    Name of the parameter to make trainable.

    required init_val Optional[Union[float, list]]

    Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.

    None verbose bool

    Whether to print the number of parameters that are added and the total number of parameters.

    True Source code in jaxley/modules/base.py
    def make_trainable(\n    self,\n    key: str,\n    init_val: Optional[Union[float, list]] = None,\n    verbose: bool = True,\n):\n    \"\"\"Make a parameter trainable.\n\n    If a parameter is made trainable, it will be returned by `get_parameters()`\n    and should then be passed to `jx.integrate(..., params=params)`.\n\n    Args:\n        key: Name of the parameter to make trainable.\n        init_val: Initial value of the parameter. If `float`, the same value is\n            used for every created parameter. If `list`, the length of the list has\n            to match the number of created parameters. If `None`, the current\n            parameter value is used and if parameter sharing is performed that the\n            current parameter value is averaged over all shared parameters.\n        verbose: Whether to print the number of parameters that are added and the\n            total number of parameters.\n    \"\"\"\n    assert (\n        key not in self.synapse_param_names and key not in self.synapse_state_names\n    ), \"Parameters of synapses can only be made trainable via the `SynapseView`.\"\n    view = self.nodes\n    view = deepcopy(view.assign(controlled_by_param=0))\n    self._make_trainable(view, key, init_val, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=True)","text":"

    Move cells or networks by adding to their (x, y, z) coordinates.

    This function is used only for visualization. It does not affect the simulation.

    Parameters:

    Name Type Description Default x float

    The amount to move in the x direction in um.

    0.0 y float

    The amount to move in the y direction in um.

    0.0 z float

    The amount to move in the z direction in um.

    0.0 update_nodes bool

    Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

    True Source code in jaxley/modules/base.py
    def move(\n    self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True\n):\n    \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        x: The amount to move in the x direction in um.\n        y: The amount to move in the y direction in um.\n        z: The amount to move in the z direction in um.\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    self._move(x, y, z, self.nodes, update_nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=True)","text":"

    Move cells or networks to a location (x, y, z).

    If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.

    If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

    Parameters:

    Name Type Description Default update_nodes bool

    Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

    True Source code in jaxley/modules/base.py
    def move_to(\n    self,\n    x: Union[float, np.ndarray] = 0.0,\n    y: Union[float, np.ndarray] = 0.0,\n    z: Union[float, np.ndarray] = 0.0,\n    update_nodes: bool = True,\n):\n    \"\"\"Move cells or networks to a location (x, y, z).\n\n    If x, y, and z are floats, then the first compartment of the first branch\n    of the first cell is moved to that float coordinate, and everything else is\n    shifted by the difference between that compartment's previous coordinate and\n    the new float location.\n\n    If x, y, and z are arrays, then they must each have a length equal to the number\n    of cells being moved. Then the first compartment of the first branch of each\n    cell is moved to the specified location.\n\n    Args:\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    self._move_to(x, y, z, self.nodes, update_nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.record","title":"record(state='v', verbose=True)","text":"

    Insert a recording into the compartment.

    Parameters:

    Name Type Description Default state str

    The name of the state to record.

    'v' verbose bool

    Whether to print number of inserted recordings.

    True Source code in jaxley/modules/base.py
    def record(self, state: str = \"v\", verbose: bool = True):\n    \"\"\"Insert a recording into the compartment.\n\n    Args:\n        state: The name of the state to record.\n        verbose: Whether to print number of inserted recordings.\"\"\"\n    view = deepcopy(self.nodes)\n    view[\"state\"] = state\n    recording_view = view[[\"comp_index\", \"state\"]]\n    recording_view = recording_view.rename(columns={\"comp_index\": \"rec_index\"})\n    self._record(recording_view, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy')","text":"

    Rotate jaxley modules clockwise. Used only for visualization.

    This function is used only for visualization. It does not affect the simulation.

    Parameters:

    Name Type Description Default degrees float

    How many degrees to rotate the module by.

    required rotation_axis str

    Either of {xy | xz | yz}.

    'xy' Source code in jaxley/modules/base.py
    def rotate(self, degrees: float, rotation_axis: str = \"xy\"):\n    \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        degrees: How many degrees to rotate the module by.\n        rotation_axis: Either of {`xy` | `xz` | `yz`}.\n    \"\"\"\n    self._rotate(degrees=degrees, rotation_axis=rotation_axis, view=self.nodes)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)","text":"

    Set parameter of module (or its view) to a new value.

    Note that this function can not be called within jax.jit or jax.grad. Instead, it should be used set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.

    Parameters:

    Name Type Description Default key str

    The name of the parameter to set.

    required val Union[float, ndarray]

    The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

    required Source code in jaxley/modules/base.py
    def set(self, key: str, val: Union[float, jnp.ndarray]):\n    \"\"\"Set parameter of module (or its view) to a new value.\n\n    Note that this function can not be called within `jax.jit` or `jax.grad`.\n    Instead, it should be used set the parameters of the module **before** the\n    simulation. Use `.data_set()` to set parameters during `jax.jit` or\n    `jax.grad`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n    \"\"\"\n    # TODO(@michaeldeistler) should we allow `.set()` for synaptic parameters\n    # without using the `SynapseView`, purely for consistency with `make_trainable`?\n    view = (\n        self.edges\n        if key in self.synapse_param_names or key in self.synapse_state_names\n        else self.nodes\n    )\n    self._set(key, val, view, view)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)","text":"

    Print detailed information about the Module or a view of it.

    Parameters:

    Name Type Description Default param_names Optional[Union[str, List[str]]]

    The names of the parameters to show. If None, all parameters are shown. NOT YET IMPLEMENTED.

    None indices bool

    Whether to show the indices of the compartments.

    True params bool

    Whether to show the parameters of the compartments.

    True states bool

    Whether to show the states of the compartments.

    True channel_names Optional[List[str]]

    The names of the channels to show. If None, all channels are shown.

    None

    Returns:

    Type Description DataFrame

    A pd.DataFrame with the requested information.

    Source code in jaxley/modules/base.py
    def show(\n    self,\n    param_names: Optional[Union[str, List[str]]] = None,  # TODO.\n    *,\n    indices: bool = True,\n    params: bool = True,\n    states: bool = True,\n    channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n    \"\"\"Print detailed information about the Module or a view of it.\n\n    Args:\n        param_names: The names of the parameters to show. If `None`, all parameters\n            are shown. NOT YET IMPLEMENTED.\n        indices: Whether to show the indices of the compartments.\n        params: Whether to show the parameters of the compartments.\n        states: Whether to show the states of the compartments.\n        channel_names: The names of the channels to show. If `None`, all channels are\n            shown.\n\n    Returns:\n        A `pd.DataFrame` with the requested information.\n    \"\"\"\n    return self._show(\n        self.nodes, param_names, indices, params, states, channel_names\n    )\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')","text":"

    One step of solving the Ordinary Differential Equation.

    This function is called inside of integrate and increments the state of the module by one time step. Calls _step_channels and _step_synapse to update the states of the channels and synapses using fwd_euler.

    Parameters:

    Name Type Description Default u Dict[str, ndarray]

    The state of the module. voltages = u[\u201cv\u201d]

    required delta_t float

    The time step.

    required external_inds Dict[str, ndarray]

    The indices of the external inputs.

    required externals Dict[str, ndarray]

    The external inputs.

    required params Dict[str, ndarray]

    The parameters of the module.

    required solver str

    The solver to use for the voltages. Either \u201cbwd_euler\u201d or \u201cfwd_euler\u201d.

    'bwd_euler' voltage_solver str

    The tridiagonal solver to used to diagonalize the coefficient matrix of the ODE system. Either \u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d, or \u201cjax.scipy.sparse\u201d.

    'jaxley.stone'

    Returns:

    Type Description Dict[str, ndarray]

    The updated state of the module.

    Source code in jaxley/modules/base.py
    def step(\n    self,\n    u: Dict[str, jnp.ndarray],\n    delta_t: float,\n    external_inds: Dict[str, jnp.ndarray],\n    externals: Dict[str, jnp.ndarray],\n    params: Dict[str, jnp.ndarray],\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"One step of solving the Ordinary Differential Equation.\n\n    This function is called inside of `integrate` and increments the state of the\n    module by one time step. Calls `_step_channels` and `_step_synapse` to update\n    the states of the channels and synapses using fwd_euler.\n\n    Args:\n        u: The state of the module. voltages = u[\"v\"]\n        delta_t: The time step.\n        external_inds: The indices of the external inputs.\n        externals: The external inputs.\n        params: The parameters of the module.\n        solver: The solver to use for the voltages. Either \"bwd_euler\" or \"fwd_euler\".\n        voltage_solver: The tridiagonal solver to used to diagonalize the\n            coefficient matrix of the ODE system. Either \"jaxley.thomas\",\n            \"jaxley.stone\", or \"jax.scipy.sparse\".\n\n    Returns:\n        The updated state of the module.\n    \"\"\"\n\n    # Extract the voltages\n    voltages = u[\"v\"]\n\n    # Extract the external inputs\n    has_current = \"i\" in externals.keys()\n    i_current = externals[\"i\"] if has_current else jnp.asarray([]).astype(\"float\")\n    i_inds = external_inds[\"i\"] if has_current else jnp.asarray([]).astype(\"int32\")\n    i_ext = self._get_external_input(\n        voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n    )\n\n    # Step of the channels.\n    u, (v_terms, const_terms) = self._step_channels(\n        u, delta_t, self.channels, self.nodes, params\n    )\n\n    # Step of the synapse.\n    u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n        u,\n        self.synapses,\n        params,\n        delta_t,\n        self.edges,\n    )\n\n    # Clamp for channels and synapses.\n    for key in externals.keys():\n        if key not in [\"i\", \"v\"]:\n            u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n    # Voltage steps.\n    cm = params[\"capacitance\"]  # Abbreviation.\n    if solver == \"bwd_euler\":\n        new_voltages = step_voltage_implicit(\n            voltages=voltages,\n            voltage_terms=(v_terms + syn_v_terms) / cm,\n            constant_terms=(const_terms + i_ext + syn_const_terms) / cm,\n            coupling_conds_upper=params[\"branch_uppers\"],\n            coupling_conds_lower=params[\"branch_lowers\"],\n            summed_coupling_conds=params[\"branch_diags\"],\n            branchpoint_conds_children=params[\"branchpoint_conds_children\"],\n            branchpoint_conds_parents=params[\"branchpoint_conds_parents\"],\n            branchpoint_weights_children=params[\"branchpoint_weights_children\"],\n            branchpoint_weights_parents=params[\"branchpoint_weights_parents\"],\n            par_inds=self.par_inds,\n            child_inds=self.child_inds,\n            nbranches=self.total_nbranches,\n            solver=voltage_solver,\n            delta_t=delta_t,\n            children_in_level=self.children_in_level,\n            parents_in_level=self.parents_in_level,\n            root_inds=self.root_inds,\n            branchpoint_group_inds=self.branchpoint_group_inds,\n            debug_states=self.debug_states,\n        )\n    else:\n        new_voltages = step_voltage_explicit(\n            voltages,\n            (v_terms + syn_v_terms) / cm,\n            (const_terms + i_ext + syn_const_terms) / cm,\n            coupling_conds_bwd=params[\"coupling_conds_bwd\"],\n            coupling_conds_fwd=params[\"coupling_conds_fwd\"],\n            branch_cond_fwd=params[\"branch_conds_fwd\"],\n            branch_cond_bwd=params[\"branch_conds_bwd\"],\n            nbranches=self.total_nbranches,\n            parents=self.comb_parents,\n            delta_t=delta_t,\n        )\n\n    u[\"v\"] = new_voltages.ravel(order=\"C\")\n\n    # Clamp for voltages.\n    if \"v\" in externals.keys():\n        u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n    return u\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)","text":"

    Insert a stimulus into the compartment.

    current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.

    This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

    Parameters:

    Name Type Description Default current Optional[ndarray]

    Current in nA.

    None Source code in jaxley/modules/base.py
    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n    \"\"\"Insert a stimulus into the compartment.\n\n    current must be a 1d array or have batch dimension of size `(num_compartments, )`\n    or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n    This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n    it should only be used for static stimuli (i.e., stimuli that do not depend\n    on the data and that should not be learned). For stimuli that depend on data\n    (or that should be learned), please use `data_stimulate()`.\n\n    Args:\n        current: Current in `nA`.\n    \"\"\"\n    self._external_input(\"i\", current, self.nodes, verbose=verbose)\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()","text":"

    Move .nodes to .jaxnodes.

    Before the actual simulation is run (via jx.integrate), all parameters of the jx.Module are stored in .nodes (a pd.DataFrame). However, for simulation, these parameters have to be moved to be jnp.ndarrays such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax() copies the .nodes to .jaxnodes.

    Source code in jaxley/modules/base.py
    def to_jax(self):\n    \"\"\"Move `.nodes` to `.jaxnodes`.\n\n    Before the actual simulation is run (via `jx.integrate`), all parameters of\n    the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n    simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n    they can be processed on GPU/TPU and such that the simulation can be\n    differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n    \"\"\"\n    self.jaxnodes = {}\n    for key, value in self.nodes.to_dict(orient=\"list\").items():\n        inds = jnp.arange(len(value))\n        self.jaxnodes[key] = jnp.asarray(value)[inds]\n\n    # `jaxedges` contains only parameters (no indices).\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    self.jaxedges = {}\n    edges = self.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.synapses):\n        for key in synapse.synapse_params:\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n        for key in synapse.synapse_states:\n            self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
    "},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, col='k', dims=(0, 1), type='line', morph_plot_kwargs={})","text":"

    Visualize the module.

    Parameters:

    Name Type Description Default ax Optional[Axes]

    An axis into which to plot.

    None col str

    The color for all branches.

    'k' dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) morph_plot_kwargs Dict

    Keyword arguments passed to the plotting function.

    {} Source code in jaxley/modules/base.py
    def vis(\n    self,\n    ax: Optional[Axes] = None,\n    col: str = \"k\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    morph_plot_kwargs: Dict = {},\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        ax: An axis into which to plot.\n        col: The color for all branches.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        morph_plot_kwargs: Keyword arguments passed to the plotting function.\n    \"\"\"\n    return self._vis(\n        dims=dims,\n        col=col,\n        ax=ax,\n        view=self.nodes,\n        type=type,\n        morph_plot_kwargs=morph_plot_kwargs,\n    )\n
    "},{"location":"reference/modules/#compartment","title":"Compartment","text":"

    Bases: Module

    Compartment class.

    This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.

    Source code in jaxley/modules/compartment.py
    class Compartment(Module):\n    \"\"\"Compartment class.\n\n    This class defines a single compartment that can be simulated by itself or\n    connected up into branches. It is the basic building block of a neuron model.\n    \"\"\"\n\n    compartment_params: Dict = {\n        \"length\": 10.0,  # um\n        \"radius\": 1.0,  # um\n        \"axial_resistivity\": 5_000.0,  # ohm cm\n        \"capacitance\": 1.0,  # uF/cm^2\n    }\n    compartment_states: Dict = {\"v\": -70.0}\n\n    def __init__(self):\n        super().__init__()\n\n        self.nseg = 1\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self.cumsum_nbranches = jnp.asarray([0, 1])\n\n        # Setting up the `nodes` for indexing.\n        self.nodes = pd.DataFrame(\n            dict(comp_index=[0], branch_index=[0], cell_index=[0])\n        )\n        self._append_params_and_states(self.compartment_params, self.compartment_states)\n\n        # Synapses.\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self.child_inds = np.asarray([]).astype(int)\n        self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n        self.par_inds = np.asarray([]).astype(int)\n        self.total_nbranchpoints = 0\n        self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n        self.children_in_level = []\n        self.parents_in_level = []\n        self.root_inds = jnp.asarray([0])\n\n        # Initialize the module.\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = True\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def init_conds(self, params):\n        return {\n            \"branchpoint_conds_children\": jnp.asarray([]),\n            \"branchpoint_conds_parents\": jnp.asarray([]),\n            \"branchpoint_weights_children\": jnp.asarray([]),\n            \"branchpoint_weights_parents\": jnp.asarray([]),\n            \"branch_uppers\": jnp.asarray([]),\n            \"branch_lowers\": jnp.asarray([]),\n            \"branch_diags\": jnp.asarray([0.0]),\n        }\n
    "},{"location":"reference/modules/#branch","title":"Branch","text":"

    Bases: Module

    Branch class.

    This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.

    Source code in jaxley/modules/branch.py
    class Branch(Module):\n    \"\"\"Branch class.\n\n    This class defines a single branch that can be simulated by itself or\n    connected to build a cell. A branch is linear segment of several compartments\n    and can be connected to no, one or more other branches at each end to build more\n    intricate cell morphologies.\n    \"\"\"\n\n    branch_params: Dict = {}\n    branch_states: Dict = {}\n\n    def __init__(\n        self,\n        compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n        nseg: Optional[int] = None,\n    ):\n        \"\"\"\n        Args:\n            compartments: A single compartment or a list of compartments that make up the\n                branch.\n            nseg: Number of segments to divide the branch into. If `compartments` is an\n                a single compartment, than the compartment is repeated `nseg` times to\n                create the branch.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(compartments, (Compartment, List)) or compartments is None\n        ), \"Only Compartment or List[Compartment] is allowed.\"\n        if isinstance(compartments, Compartment):\n            assert (\n                nseg is not None\n            ), \"If `compartments` is not a list then you have to set `nseg`.\"\n        compartments = Compartment() if compartments is None else compartments\n        nseg = 1 if nseg is None else nseg\n\n        if isinstance(compartments, Compartment):\n            compartment_list = [compartments] * nseg\n        else:\n            compartment_list = compartments\n\n        self.nseg = len(compartment_list)\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self.cumsum_nbranches = jnp.asarray([0, 1])\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n        self._append_params_and_states(self.branch_params, self.branch_states)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg).tolist()\n        self.nodes[\"branch_index\"] = [0] * self.nseg\n        self.nodes[\"cell_index\"] = [0] * self.nseg\n\n        # Channels.\n        self._gather_channels_from_constituents(compartment_list)\n\n        # Synapse indexing.\n        self.syn_edges = pd.DataFrame(\n            dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n        )\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self.child_inds = np.asarray([]).astype(int)\n        self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n        self.par_inds = np.asarray([]).astype(int)\n        self.total_nbranchpoints = 0\n        self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n        self.children_in_level = []\n        self.parents_in_level = []\n        self.root_inds = jnp.asarray([0])\n\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = False\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key in [\"comp\", \"loc\"]:\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            compview = CompartmentView(self, view)\n            return compview if key == \"comp\" else compview.loc\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, CompartmentView, [\"comp\", \"loc\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        conds = self.init_branch_conds(\n            params[\"axial_resistivity\"], params[\"radius\"], params[\"length\"], self.nseg\n        )\n        cond_params = {\n            \"branchpoint_conds_children\": jnp.asarray([]),\n            \"branchpoint_conds_parents\": jnp.asarray([]),\n            \"branchpoint_weights_children\": jnp.asarray([]),\n            \"branchpoint_weights_parents\": jnp.asarray([]),\n        }\n        cond_params[\"branch_lowers\"] = conds[0]\n        cond_params[\"branch_uppers\"] = conds[1]\n        cond_params[\"branch_diags\"] = conds[2]\n\n        return cond_params\n\n    @staticmethod\n    def init_branch_conds(\n        axial_resistivity: jnp.ndarray,\n        radiuses: jnp.ndarray,\n        lengths: jnp.ndarray,\n        nseg: int,\n    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\n\n        Args:\n            axial_resistivity: Axial resistivity of each compartment.\n            radiuses: Radius of each compartment.\n            lengths: Length of each compartment.\n            nseg: Number of compartments in the branch.\n\n        Returns:\n            Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.\n        \"\"\"\n\n        # Compute coupling conductance for segments within a branch.\n        # `radius`: um\n        # `r_a`: ohm cm\n        # `length_single_compartment`: um\n        # `coupling_conds`: S * um / cm / um^2 = S / cm / um\n        r1 = radiuses[:-1]\n        r2 = radiuses[1:]\n        r_a1 = axial_resistivity[:-1]\n        r_a2 = axial_resistivity[1:]\n        l1 = lengths[:-1]\n        l2 = lengths[1:]\n        coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)\n        coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)\n\n        # Compute the summed coupling conductances of each compartment.\n        summed_coupling_conds = jnp.zeros((nseg))\n        summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)\n        summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)\n        return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds\n\n    def __len__(self) -> int:\n        return self.nseg\n
    "},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, nseg=None)","text":"

    Parameters:

    Name Type Description Default compartments Optional[Union[Compartment, List[Compartment]]]

    A single compartment or a list of compartments that make up the branch.

    None nseg Optional[int]

    Number of segments to divide the branch into. If compartments is an a single compartment, than the compartment is repeated nseg times to create the branch.

    None Source code in jaxley/modules/branch.py
    def __init__(\n    self,\n    compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n    nseg: Optional[int] = None,\n):\n    \"\"\"\n    Args:\n        compartments: A single compartment or a list of compartments that make up the\n            branch.\n        nseg: Number of segments to divide the branch into. If `compartments` is an\n            a single compartment, than the compartment is repeated `nseg` times to\n            create the branch.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(compartments, (Compartment, List)) or compartments is None\n    ), \"Only Compartment or List[Compartment] is allowed.\"\n    if isinstance(compartments, Compartment):\n        assert (\n            nseg is not None\n        ), \"If `compartments` is not a list then you have to set `nseg`.\"\n    compartments = Compartment() if compartments is None else compartments\n    nseg = 1 if nseg is None else nseg\n\n    if isinstance(compartments, Compartment):\n        compartment_list = [compartments] * nseg\n    else:\n        compartment_list = compartments\n\n    self.nseg = len(compartment_list)\n    self.total_nbranches = 1\n    self.nbranches_per_cell = [1]\n    self.cumsum_nbranches = jnp.asarray([0, 1])\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n    self._append_params_and_states(self.branch_params, self.branch_states)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg).tolist()\n    self.nodes[\"branch_index\"] = [0] * self.nseg\n    self.nodes[\"cell_index\"] = [0] * self.nseg\n\n    # Channels.\n    self._gather_channels_from_constituents(compartment_list)\n\n    # Synapse indexing.\n    self.syn_edges = pd.DataFrame(\n        dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n    )\n    self.branch_edges = pd.DataFrame(\n        dict(parent_branch_index=[], child_branch_index=[])\n    )\n\n    # For morphology indexing.\n    self.child_inds = np.asarray([]).astype(int)\n    self.child_belongs_to_branchpoint = np.asarray([]).astype(int)\n    self.par_inds = np.asarray([]).astype(int)\n    self.total_nbranchpoints = 0\n    self.branchpoint_group_inds = np.asarray([]).astype(int)\n\n    self.children_in_level = []\n    self.parents_in_level = []\n    self.root_inds = jnp.asarray([0])\n\n    self.initialize()\n    self.init_syns()\n    self.initialized_conds = False\n\n    # Coordinates.\n    self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
    "},{"location":"reference/modules/#jaxley.modules.branch.Branch.init_branch_conds","title":"init_branch_conds(axial_resistivity, radiuses, lengths, nseg) staticmethod","text":"

    Given an axial resisitivity, set the coupling conductances.

    Parameters:

    Name Type Description Default axial_resistivity ndarray

    Axial resistivity of each compartment.

    required radiuses ndarray

    Radius of each compartment.

    required lengths ndarray

    Length of each compartment.

    required nseg int

    Number of compartments in the branch.

    required

    Returns:

    Type Description Tuple[ndarray, ndarray, ndarray]

    Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.

    Source code in jaxley/modules/branch.py
    @staticmethod\ndef init_branch_conds(\n    axial_resistivity: jnp.ndarray,\n    radiuses: jnp.ndarray,\n    lengths: jnp.ndarray,\n    nseg: int,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\n\n    Args:\n        axial_resistivity: Axial resistivity of each compartment.\n        radiuses: Radius of each compartment.\n        lengths: Length of each compartment.\n        nseg: Number of compartments in the branch.\n\n    Returns:\n        Tuple of forward coupling conductances, backward coupling conductances, and summed coupling conductances.\n    \"\"\"\n\n    # Compute coupling conductance for segments within a branch.\n    # `radius`: um\n    # `r_a`: ohm cm\n    # `length_single_compartment`: um\n    # `coupling_conds`: S * um / cm / um^2 = S / cm / um\n    r1 = radiuses[:-1]\n    r2 = radiuses[1:]\n    r_a1 = axial_resistivity[:-1]\n    r_a2 = axial_resistivity[1:]\n    l1 = lengths[:-1]\n    l2 = lengths[1:]\n    coupling_conds_bwd = compute_coupling_cond(r1, r2, r_a1, r_a2, l1, l2)\n    coupling_conds_fwd = compute_coupling_cond(r2, r1, r_a2, r_a1, l2, l1)\n\n    # Compute the summed coupling conductances of each compartment.\n    summed_coupling_conds = jnp.zeros((nseg))\n    summed_coupling_conds = summed_coupling_conds.at[1:].add(coupling_conds_fwd)\n    summed_coupling_conds = summed_coupling_conds.at[:-1].add(coupling_conds_bwd)\n    return coupling_conds_fwd, coupling_conds_bwd, summed_coupling_conds\n
    "},{"location":"reference/modules/#cell","title":"Cell","text":"

    Bases: Module

    Cell class.

    This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.

    Source code in jaxley/modules/cell.py
    class Cell(Module):\n    \"\"\"Cell class.\n\n    This class defines a single cell that can be simulated by itself or\n    connected with synapses to build a network. A cell is made up of several branches\n    and supports intricate cell morphologies.\n    \"\"\"\n\n    cell_params: Dict = {}\n    cell_states: Dict = {}\n\n    def __init__(\n        self,\n        branches: Optional[Union[Branch, List[Branch]]] = None,\n        parents: Optional[List[int]] = None,\n        xyzr: Optional[List[np.ndarray]] = None,\n    ):\n        \"\"\"Initialize a cell.\n\n        Args:\n            branches: A single branch or a list of branches that make up the cell.\n                If a single branch is provided, then the branch is repeated `len(parents)`\n                times to create the cell.\n            parents: The parent branch index for each branch. The first branch has no\n                parent and is therefore set to -1.\n            xyzr: For every branch, the x, y, and z coordinates and the radius at the\n                traced coordinates. Note that this is the full tracing (from SWC), not\n                the stick representation coordinates.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(branches, (Branch, List)) or branches is None\n        ), \"Only Branch or List[Branch] is allowed.\"\n        if branches is not None:\n            assert (\n                parents is not None\n            ), \"If `branches` is not a list then you have to set `parents`.\"\n        if isinstance(branches, List):\n            assert len(parents) == len(\n                branches\n            ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n        branches = Branch() if branches is None else branches\n        parents = [-1] if parents is None else parents\n\n        if isinstance(branches, Branch):\n            branch_list = [branches for _ in range(len(parents))]\n        else:\n            branch_list = branches\n\n        if xyzr is not None:\n            assert len(xyzr) == len(parents)\n            self.xyzr = xyzr\n        else:\n            # For every branch (`len(parents)`), we have a start and end point (`2`) and\n            # a (x,y,z,r) coordinate for each of them (`4`).\n            # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n            # (potentially learned) length of every compartment, we only populate\n            # self.xyzr at `.vis()`.\n            self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n        self.nseg = branch_list[0].nseg\n        self.total_nbranches = len(branch_list)\n        self.nbranches_per_cell = [len(branch_list)]\n        self.comb_parents = jnp.asarray(parents)\n        self.comb_children = compute_children_indices(self.comb_parents)\n        self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n        self._append_params_and_states(self.cell_params, self.cell_states)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n        self.nodes[\"branch_index\"] = (\n            np.arange(self.nseg * self.total_nbranches) // self.nseg\n        ).tolist()\n        self.nodes[\"cell_index\"] = [0] * (self.nseg * self.total_nbranches)\n\n        # Channels.\n        self._gather_channels_from_constituents(branch_list)\n\n        # Synapse indexing.\n        self.syn_edges = pd.DataFrame(\n            dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n        )\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[1:],\n                child_branch_index=np.arange(1, self.total_nbranches),\n            )\n        )\n\n        # For morphology indexing.\n        par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n        self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n        self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n\n        # TODO: does order have to be preserved?\n        self.par_inds = np.unique(par_inds)\n        self.total_nbranchpoints = len(self.par_inds)\n        self.root_inds = jnp.asarray([0])\n\n        self.initialize()\n\n        self.init_syns()\n        self.initialized_conds = False\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key == \"branch\":\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return BranchView(self, view)\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, BranchView, [\"branch\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_morph(self):\n        \"\"\"Initialize morphology.\"\"\"\n\n        # For Jaxley custom implementation.\n        children_and_parents = compute_morphology_indices_in_levels(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.par_inds,\n            self.child_inds,\n        )\n        self.branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.nseg,\n            self.total_nbranches,\n        )\n        parents = self.comb_parents\n        children_inds = children_and_parents[\"children\"]\n        parents_inds = children_and_parents[\"parents\"]\n\n        levels = compute_levels(parents)\n        self.children_in_level = compute_children_in_level(levels, children_inds)\n        self.parents_in_level = compute_parents_in_level(\n            levels, self.par_inds, parents_inds\n        )\n\n        self.initialized_morph = True\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n        nbranches = self.total_nbranches\n        nseg = self.nseg\n\n        axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n        radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n        lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n        conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n            axial_resistivity, radiuses, lengths, self.nseg\n        )\n        coupling_conds_fwd = conds[0]\n        coupling_conds_bwd = conds[1]\n        summed_coupling_conds = conds[2]\n\n        # The conductance from the children to the branch point.\n        branchpoint_conds_children = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The conductance from the parents to the branch point.\n        branchpoint_conds_parents = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        # Weights with which the compartments influence their nearby node.\n        # The impact of the children on the branch point.\n        branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The impact of parents on the branch point.\n        branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        summed_coupling_conds = self.update_summed_coupling_conds(\n            summed_coupling_conds,\n            self.child_inds,\n            self.par_inds,\n            branchpoint_conds_children,\n            branchpoint_conds_parents,\n        )\n\n        cond_params = {\n            \"branch_uppers\": coupling_conds_bwd,\n            \"branch_lowers\": coupling_conds_fwd,\n            \"branch_diags\": summed_coupling_conds,\n            \"branchpoint_conds_children\": branchpoint_conds_children,\n            \"branchpoint_conds_parents\": branchpoint_conds_parents,\n            \"branchpoint_weights_children\": branchpoint_weights_children,\n            \"branchpoint_weights_parents\": branchpoint_weights_parents,\n        }\n        return cond_params\n\n    @staticmethod\n    def update_summed_coupling_conds(\n        summed_conds,\n        child_inds,\n        par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    ):\n        \"\"\"Perform updates on the diagonal based on conductances of the branchpoints.\n\n        Args:\n            summed_conds: shape [num_branches, nseg]\n            child_inds: shape [num_branches - 1]\n            conds_fwd: shape [num_branches - 1]\n            conds_bwd: shape [num_branches - 1]\n            parents: shape [num_branches]\n\n        Returns:\n            Updated `summed_coupling_conds`.\n        \"\"\"\n        summed_conds = summed_conds.at[child_inds, 0].add(branchpoint_conds_children)\n        summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents)\n        return summed_conds\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)","text":"

    Initialize a cell.

    Parameters:

    Name Type Description Default branches Optional[Union[Branch, List[Branch]]]

    A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents) times to create the cell.

    None parents Optional[List[int]]

    The parent branch index for each branch. The first branch has no parent and is therefore set to -1.

    None xyzr Optional[List[ndarray]]

    For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.

    None Source code in jaxley/modules/cell.py
    def __init__(\n    self,\n    branches: Optional[Union[Branch, List[Branch]]] = None,\n    parents: Optional[List[int]] = None,\n    xyzr: Optional[List[np.ndarray]] = None,\n):\n    \"\"\"Initialize a cell.\n\n    Args:\n        branches: A single branch or a list of branches that make up the cell.\n            If a single branch is provided, then the branch is repeated `len(parents)`\n            times to create the cell.\n        parents: The parent branch index for each branch. The first branch has no\n            parent and is therefore set to -1.\n        xyzr: For every branch, the x, y, and z coordinates and the radius at the\n            traced coordinates. Note that this is the full tracing (from SWC), not\n            the stick representation coordinates.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(branches, (Branch, List)) or branches is None\n    ), \"Only Branch or List[Branch] is allowed.\"\n    if branches is not None:\n        assert (\n            parents is not None\n        ), \"If `branches` is not a list then you have to set `parents`.\"\n    if isinstance(branches, List):\n        assert len(parents) == len(\n            branches\n        ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n    branches = Branch() if branches is None else branches\n    parents = [-1] if parents is None else parents\n\n    if isinstance(branches, Branch):\n        branch_list = [branches for _ in range(len(parents))]\n    else:\n        branch_list = branches\n\n    if xyzr is not None:\n        assert len(xyzr) == len(parents)\n        self.xyzr = xyzr\n    else:\n        # For every branch (`len(parents)`), we have a start and end point (`2`) and\n        # a (x,y,z,r) coordinate for each of them (`4`).\n        # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n        # (potentially learned) length of every compartment, we only populate\n        # self.xyzr at `.vis()`.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n    self.nseg = branch_list[0].nseg\n    self.total_nbranches = len(branch_list)\n    self.nbranches_per_cell = [len(branch_list)]\n    self.comb_parents = jnp.asarray(parents)\n    self.comb_children = compute_children_indices(self.comb_parents)\n    self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n    self._append_params_and_states(self.cell_params, self.cell_states)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n    self.nodes[\"branch_index\"] = (\n        np.arange(self.nseg * self.total_nbranches) // self.nseg\n    ).tolist()\n    self.nodes[\"cell_index\"] = [0] * (self.nseg * self.total_nbranches)\n\n    # Channels.\n    self._gather_channels_from_constituents(branch_list)\n\n    # Synapse indexing.\n    self.syn_edges = pd.DataFrame(\n        dict(global_pre_comp_index=[], global_post_comp_index=[], type=\"\")\n    )\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[1:],\n            child_branch_index=np.arange(1, self.total_nbranches),\n        )\n    )\n\n    # For morphology indexing.\n    par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n    self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n    self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n\n    # TODO: does order have to be preserved?\n    self.par_inds = np.unique(par_inds)\n    self.total_nbranchpoints = len(self.par_inds)\n    self.root_inds = jnp.asarray([0])\n\n    self.initialize()\n\n    self.init_syns()\n    self.initialized_conds = False\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.init_conds","title":"init_conds(params)","text":"

    Given an axial resisitivity, set the coupling conductances.

    Source code in jaxley/modules/cell.py
    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n    nbranches = self.total_nbranches\n    nseg = self.nseg\n\n    axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n    radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n    lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n    conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n        axial_resistivity, radiuses, lengths, self.nseg\n    )\n    coupling_conds_fwd = conds[0]\n    coupling_conds_bwd = conds[1]\n    summed_coupling_conds = conds[2]\n\n    # The conductance from the children to the branch point.\n    branchpoint_conds_children = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The conductance from the parents to the branch point.\n    branchpoint_conds_parents = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    # Weights with which the compartments influence their nearby node.\n    # The impact of the children on the branch point.\n    branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The impact of parents on the branch point.\n    branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    summed_coupling_conds = self.update_summed_coupling_conds(\n        summed_coupling_conds,\n        self.child_inds,\n        self.par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    )\n\n    cond_params = {\n        \"branch_uppers\": coupling_conds_bwd,\n        \"branch_lowers\": coupling_conds_fwd,\n        \"branch_diags\": summed_coupling_conds,\n        \"branchpoint_conds_children\": branchpoint_conds_children,\n        \"branchpoint_conds_parents\": branchpoint_conds_parents,\n        \"branchpoint_weights_children\": branchpoint_weights_children,\n        \"branchpoint_weights_parents\": branchpoint_weights_parents,\n    }\n    return cond_params\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.init_morph","title":"init_morph()","text":"

    Initialize morphology.

    Source code in jaxley/modules/cell.py
    def init_morph(self):\n    \"\"\"Initialize morphology.\"\"\"\n\n    # For Jaxley custom implementation.\n    children_and_parents = compute_morphology_indices_in_levels(\n        len(self.par_inds),\n        self.child_belongs_to_branchpoint,\n        self.par_inds,\n        self.child_inds,\n    )\n    self.branchpoint_group_inds = build_branchpoint_group_inds(\n        len(self.par_inds),\n        self.child_belongs_to_branchpoint,\n        self.nseg,\n        self.total_nbranches,\n    )\n    parents = self.comb_parents\n    children_inds = children_and_parents[\"children\"]\n    parents_inds = children_and_parents[\"parents\"]\n\n    levels = compute_levels(parents)\n    self.children_in_level = compute_children_in_level(levels, children_inds)\n    self.parents_in_level = compute_parents_in_level(\n        levels, self.par_inds, parents_inds\n    )\n\n    self.initialized_morph = True\n
    "},{"location":"reference/modules/#jaxley.modules.cell.Cell.update_summed_coupling_conds","title":"update_summed_coupling_conds(summed_conds, child_inds, par_inds, branchpoint_conds_children, branchpoint_conds_parents) staticmethod","text":"

    Perform updates on the diagonal based on conductances of the branchpoints.

    Parameters:

    Name Type Description Default summed_conds

    shape [num_branches, nseg]

    required child_inds

    shape [num_branches - 1]

    required conds_fwd

    shape [num_branches - 1]

    required conds_bwd

    shape [num_branches - 1]

    required parents

    shape [num_branches]

    required

    Returns:

    Type Description

    Updated summed_coupling_conds.

    Source code in jaxley/modules/cell.py
    @staticmethod\ndef update_summed_coupling_conds(\n    summed_conds,\n    child_inds,\n    par_inds,\n    branchpoint_conds_children,\n    branchpoint_conds_parents,\n):\n    \"\"\"Perform updates on the diagonal based on conductances of the branchpoints.\n\n    Args:\n        summed_conds: shape [num_branches, nseg]\n        child_inds: shape [num_branches - 1]\n        conds_fwd: shape [num_branches - 1]\n        conds_bwd: shape [num_branches - 1]\n        parents: shape [num_branches]\n\n    Returns:\n        Updated `summed_coupling_conds`.\n    \"\"\"\n    summed_conds = summed_conds.at[child_inds, 0].add(branchpoint_conds_children)\n    summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents)\n    return summed_conds\n
    "},{"location":"reference/modules/#network","title":"Network","text":"

    Bases: Module

    Network class.

    This class defines a network of cells that can be connected with synapses.

    Source code in jaxley/modules/network.py
    class Network(Module):\n    \"\"\"Network class.\n\n    This class defines a network of cells that can be connected with synapses.\n    \"\"\"\n\n    network_params: Dict = {}\n    network_states: Dict = {}\n\n    def __init__(\n        self,\n        cells: List[Cell],\n    ):\n        \"\"\"Initialize network of cells and synapses.\n\n        Args:\n            cells: A list of cells that make up the network.\n        \"\"\"\n        super().__init__()\n        for cell in cells:\n            self.xyzr += deepcopy(cell.xyzr)\n\n        self.cells = cells\n        self.nseg = cells[0].nseg\n        self._append_params_and_states(self.network_params, self.network_states)\n\n        self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]\n        self.nbranchpoints_per_cell = [cell.total_nbranchpoints for cell in self.cells]\n        self.total_nbranches = sum(self.nbranches_per_cell)\n        self.cumsum_nbranches = jnp.cumsum(jnp.asarray([0] + self.nbranches_per_cell))\n        self.cumsum_nbranchpoints = jnp.cumsum(\n            jnp.asarray([0] + self.nbranchpoints_per_cell)\n        )\n\n        self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n        self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n        self.nodes[\"branch_index\"] = (\n            np.arange(self.nseg * self.total_nbranches) // self.nseg\n        ).tolist()\n        self.nodes[\"cell_index\"] = list(\n            itertools.chain(\n                *[[i] * (self.nseg * b) for i, b in enumerate(self.nbranches_per_cell)]\n            )\n        )\n\n        parents = [cell.comb_parents for cell in self.cells]\n        self.comb_parents = jnp.concatenate(\n            [p.at[1:].add(self.cumsum_nbranches[i]) for i, p in enumerate(parents)]\n        )\n\n        # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n        # branch, apart from those branches which do not have a parent (i.e.\n        # -1 in parents). For every branch, tracks the global index of that branch\n        # (`child_branch_index`) and the global index of its parent\n        # (`parent_branch_index`).\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[self.comb_parents != -1],\n                child_branch_index=np.where(self.comb_parents != -1)[0],\n            )\n        )\n\n        # For morphology indexing.\n        par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n        self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n        self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n        self.par_inds = np.unique(par_inds)  # TODO: does order have to be preserved?\n        self.total_nbranchpoints = len(self.par_inds)\n        self.root_inds = self.cumsum_nbranches[:-1]\n\n        # Channels.\n        self._gather_channels_from_constituents(cells)\n\n        self.initialize()\n        self.init_syns()\n        self.initialized_conds = False\n\n    def __getattr__(self, key: str):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        if key == \"cell\":\n            view = deepcopy(self.nodes)\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return CellView(self, view)\n        elif key in self.synapse_names:\n            type_index = self.synapse_names.index(key)\n            return SynapseView(self, self.edges, key, self.synapses[type_index])\n        elif key in self.group_nodes:\n            inds = self.group_nodes[key].index.values\n            view = self.nodes.loc[inds]\n            view[\"global_comp_index\"] = view[\"comp_index\"]\n            view[\"global_branch_index\"] = view[\"branch_index\"]\n            view[\"global_cell_index\"] = view[\"cell_index\"]\n            return GroupView(self, view, CellView, [\"cell\"])\n        else:\n            raise KeyError(f\"Key {key} not recognized.\")\n\n    def init_morph(self):\n        self.branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self.par_inds),\n            self.child_belongs_to_branchpoint,\n            self.nseg,\n            self.total_nbranches,\n        )\n        self.children_in_level = merge_cells(\n            self.cumsum_nbranches,\n            self.cumsum_nbranchpoints,\n            [cell.children_in_level for cell in self.cells],\n            exclude_first=False,\n        )\n        self.parents_in_level = merge_cells(\n            self.cumsum_nbranches,\n            self.cumsum_nbranchpoints,\n            [cell.parents_in_level for cell in self.cells],\n            exclude_first=False,\n        )\n        self.initialized_morph = True\n\n    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n        \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n        nbranches = self.total_nbranches\n        nseg = self.nseg\n        parents = self.comb_parents\n\n        axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n        radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n        lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n        conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n            axial_resistivity, radiuses, lengths, self.nseg\n        )\n        coupling_conds_fwd = conds[0]\n        coupling_conds_bwd = conds[1]\n        summed_coupling_conds = conds[2]\n\n        # The conductance from the children to the branch point.\n        branchpoint_conds_children = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The conductance from the parents to the branch point.\n        branchpoint_conds_parents = vmap(\n            compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n        )(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        # Weights with which the compartments influence their nearby node.\n        # The impact of the children on the branch point.\n        branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.child_inds, 0],\n            axial_resistivity[self.child_inds, 0],\n            lengths[self.child_inds, 0],\n        )\n        # The impact of parents on the branch point.\n        branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            radiuses[self.par_inds, -1],\n            axial_resistivity[self.par_inds, -1],\n            lengths[self.par_inds, -1],\n        )\n\n        summed_coupling_conds = Cell.update_summed_coupling_conds(\n            summed_coupling_conds,\n            self.child_inds,\n            self.par_inds,\n            branchpoint_conds_children,\n            branchpoint_conds_parents,\n        )\n\n        cond_params = {\n            \"branch_uppers\": coupling_conds_bwd,\n            \"branch_lowers\": coupling_conds_fwd,\n            \"branch_diags\": summed_coupling_conds,\n            \"branchpoint_conds_children\": branchpoint_conds_children,\n            \"branchpoint_conds_parents\": branchpoint_conds_parents,\n            \"branchpoint_weights_children\": branchpoint_weights_children,\n            \"branchpoint_weights_parents\": branchpoint_weights_parents,\n        }\n        return cond_params\n\n    def init_syns(self):\n        \"\"\"Initialize synapses.\"\"\"\n        self.synapses = []\n\n        # TODO(@michaeldeistler): should we also track this for channels?\n        self.synapse_names = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n\n        self.initialized_syns = True\n\n    def _step_synapse(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n        states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n        states, current_terms = self._synapse_currents(\n            states, syn_channels, params, delta_t, edges\n        )\n        return states, current_terms\n\n    def _step_synapse_state(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Dict:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"global_pre_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"global_post_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # State updates.\n            states_updated = synapse_type.update_states(\n                synapse_states,\n                delta_t,\n                voltages[pre_inds],\n                voltages[post_inds],\n                synapse_params,\n            )\n\n            # Rebuild state.\n            for key, val in states_updated.items():\n                states[key] = val\n\n        return states\n\n    def _synapse_currents(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"global_pre_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"global_post_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        syn_voltage_terms = jnp.zeros_like(voltages)\n        syn_constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            # Get pre and post indexes of the current synapse type.\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # Compute slope and offset of the current through every synapse.\n            pre_v_and_perturbed = jnp.stack(\n                [voltages[pre_inds], voltages[pre_inds] + diff]\n            )\n            post_v_and_perturbed = jnp.stack(\n                [voltages[post_inds], voltages[post_inds] + diff]\n            )\n            synapse_currents = vmap(\n                synapse_type.compute_current, in_axes=(None, 0, 0, None)\n            )(\n                synapse_states,\n                pre_v_and_perturbed,\n                post_v_and_perturbed,\n                synapse_params,\n            )\n            synapse_currents_dist = convert_point_process_to_distributed(\n                synapse_currents,\n                params[\"radius\"][post_inds],\n                params[\"length\"][post_inds],\n            )\n\n            # Split into voltage and constant terms.\n            voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n            constant_term = (\n                synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n            )\n\n            # Gather slope and offset for every postsynaptic compartment.\n            gathered_syn_currents = gather_synapes(\n                len(voltages),\n                post_inds,\n                voltage_term,\n                constant_term,\n            )\n            syn_voltage_terms += gathered_syn_currents[0]\n            syn_constant_terms -= gathered_syn_currents[1]\n\n            # Add the synaptic currents through every compartment as state.\n            # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n            # compartments in the network.\n            # `[0]` because we only use the non-perturbed voltage.\n            states[f\"{synapse_type._name}_current\"] = synapse_currents[0]\n\n        return states, (syn_voltage_terms, syn_constant_terms)\n\n    def vis(\n        self,\n        detail: str = \"full\",\n        ax: Optional[Axes] = None,\n        col: str = \"k\",\n        synapse_col: str = \"b\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        layers: Optional[List] = None,\n        morph_plot_kwargs: Dict = {},\n        synapse_plot_kwargs: Dict = {},\n        synapse_scatter_kwargs: Dict = {},\n        networkx_options: Dict = {},\n        layer_kwargs: Dict = {},\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            detail: Either of [point, full]. `point` visualizes every neuron in the\n                network as a dot (and it uses `networkx` to obtain cell positions).\n                `full` plots the full morphology of every neuron. It requires that\n                `compute_xyz()` has been run and allows for indivual neurons to be\n                moved with `.move()`.\n            col: The color in which cells are plotted. Only takes effect if\n                `detail='full'`.\n            type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n            synapse_col: The color in which synapses are plotted. Only takes effect if\n                `detail='full'`.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            layers: Allows to plot the network in layers. Should provide the number of\n                neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n                neurons, 10 hidden layer neurons, and 1 output neuron.\n            morph_plot_kwargs: Keyword arguments passed to the plotting function for\n                cell morphologies. Only takes effect for `detail='full'`.\n            synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n                syanpses. Only takes effect for `detail='full'`.\n            synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n                for the end point of synapses. Only takes effect for `detail='full'`.\n            networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n                `detail='point'`.\n            layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n                Can have the following entries: `within_layer_offset` (float),\n                `between_layer_offset` (float), `vertical_layers` (bool).\n        \"\"\"\n        if detail == \"point\":\n            graph = self._build_graph(layers)\n\n            if layers is not None:\n                pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n                nx.draw(graph, pos, with_labels=True, **networkx_options)\n            else:\n                nx.draw(graph, with_labels=True, **networkx_options)\n        elif detail == \"full\":\n            if layers is not None:\n                # Assemble cells in the network into layers.\n                global_counter = 0\n                layers_config = {\n                    \"within_layer_offset\": 500.0,\n                    \"between_layer_offset\": 1500.0,\n                    \"vertical_layers\": False,\n                }\n                layers_config.update(layer_kwargs)\n                for layer_ind, num_in_layer in enumerate(layers):\n                    for ind_within_layer in range(num_in_layer):\n                        if layers_config[\"vertical_layers\"]:\n                            x_offset = (\n                                ind_within_layer - (num_in_layer - 1) / 2\n                            ) * layers_config[\"within_layer_offset\"]\n                            y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n                                \"between_layer_offset\"\n                            ]\n                        else:\n                            x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n                            y_offset = (\n                                ind_within_layer - (num_in_layer - 1) / 2\n                            ) * layers_config[\"within_layer_offset\"]\n\n                        self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n                        global_counter += 1\n            ax = self._vis(\n                dims=dims,\n                col=col,\n                ax=ax,\n                type=type,\n                view=self.nodes,\n                morph_plot_kwargs=morph_plot_kwargs,\n            )\n\n            pre_locs = self.edges[\"pre_locs\"].to_numpy()\n            post_locs = self.edges[\"post_locs\"].to_numpy()\n            pre_branch = self.edges[\"global_pre_branch_index\"].to_numpy()\n            post_branch = self.edges[\"global_post_branch_index\"].to_numpy()\n\n            dims_np = np.asarray(dims)\n\n            for pre_loc, post_loc, pre_b, post_b in zip(\n                pre_locs, post_locs, pre_branch, post_branch\n            ):\n                pre_coord = self.xyzr[pre_b]\n                if len(pre_coord) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(pre_coord) - 1) * pre_loc)\n                    pre_coord = pre_coord[middle_ind]\n\n                post_coord = self.xyzr[post_b]\n                if len(post_coord) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    post_coord = (\n                        post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n                    )\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(post_coord) - 1) * post_loc)\n                    post_coord = post_coord[middle_ind]\n\n                coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n                ax.plot(\n                    coords[0],\n                    coords[1],\n                    c=synapse_col,\n                    **synapse_plot_kwargs,\n                )\n                ax.scatter(\n                    post_coord[dims_np[0]],\n                    post_coord[dims_np[1]],\n                    c=synapse_col,\n                    **synapse_scatter_kwargs,\n                )\n        else:\n            raise ValueError(\"detail must be in {full, point}.\")\n\n        return ax\n\n    def _build_graph(self, layers: Optional[List] = None, **options):\n        graph = nx.DiGraph()\n\n        def build_extents(*subset_sizes):\n            return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))\n\n        if layers is not None:\n            extents = build_extents(*layers)\n            layers = [range(start, end) for start, end in extents]\n            for i, layer in enumerate(layers):\n                graph.add_nodes_from(layer, layer=i)\n        else:\n            graph.add_nodes_from(range(len(self.cells)))\n\n        pre_cell = self.edges[\"pre_cell_index\"].to_numpy()\n        post_cell = self.edges[\"post_cell_index\"].to_numpy()\n\n        inds = np.stack([pre_cell, post_cell]).T\n        graph.add_edges_from(inds)\n\n        return graph\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)","text":"

    Initialize network of cells and synapses.

    Parameters:

    Name Type Description Default cells List[Cell]

    A list of cells that make up the network.

    required Source code in jaxley/modules/network.py
    def __init__(\n    self,\n    cells: List[Cell],\n):\n    \"\"\"Initialize network of cells and synapses.\n\n    Args:\n        cells: A list of cells that make up the network.\n    \"\"\"\n    super().__init__()\n    for cell in cells:\n        self.xyzr += deepcopy(cell.xyzr)\n\n    self.cells = cells\n    self.nseg = cells[0].nseg\n    self._append_params_and_states(self.network_params, self.network_states)\n\n    self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells]\n    self.nbranchpoints_per_cell = [cell.total_nbranchpoints for cell in self.cells]\n    self.total_nbranches = sum(self.nbranches_per_cell)\n    self.cumsum_nbranches = jnp.cumsum(jnp.asarray([0] + self.nbranches_per_cell))\n    self.cumsum_nbranchpoints = jnp.cumsum(\n        jnp.asarray([0] + self.nbranchpoints_per_cell)\n    )\n\n    self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n    self.nodes[\"comp_index\"] = np.arange(self.nseg * self.total_nbranches).tolist()\n    self.nodes[\"branch_index\"] = (\n        np.arange(self.nseg * self.total_nbranches) // self.nseg\n    ).tolist()\n    self.nodes[\"cell_index\"] = list(\n        itertools.chain(\n            *[[i] * (self.nseg * b) for i, b in enumerate(self.nbranches_per_cell)]\n        )\n    )\n\n    parents = [cell.comb_parents for cell in self.cells]\n    self.comb_parents = jnp.concatenate(\n        [p.at[1:].add(self.cumsum_nbranches[i]) for i, p in enumerate(parents)]\n    )\n\n    # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n    # branch, apart from those branches which do not have a parent (i.e.\n    # -1 in parents). For every branch, tracks the global index of that branch\n    # (`child_branch_index`) and the global index of its parent\n    # (`parent_branch_index`).\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[self.comb_parents != -1],\n            child_branch_index=np.where(self.comb_parents != -1)[0],\n        )\n    )\n\n    # For morphology indexing.\n    par_inds = self.branch_edges[\"parent_branch_index\"].to_numpy()\n    self.child_inds = self.branch_edges[\"child_branch_index\"].to_numpy()\n    self.child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n    self.par_inds = np.unique(par_inds)  # TODO: does order have to be preserved?\n    self.total_nbranchpoints = len(self.par_inds)\n    self.root_inds = self.cumsum_nbranches[:-1]\n\n    # Channels.\n    self._gather_channels_from_constituents(cells)\n\n    self.initialize()\n    self.init_syns()\n    self.initialized_conds = False\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.init_conds","title":"init_conds(params)","text":"

    Given an axial resisitivity, set the coupling conductances.

    Source code in jaxley/modules/network.py
    def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]:\n    \"\"\"Given an axial resisitivity, set the coupling conductances.\"\"\"\n    nbranches = self.total_nbranches\n    nseg = self.nseg\n    parents = self.comb_parents\n\n    axial_resistivity = jnp.reshape(params[\"axial_resistivity\"], (nbranches, nseg))\n    radiuses = jnp.reshape(params[\"radius\"], (nbranches, nseg))\n    lengths = jnp.reshape(params[\"length\"], (nbranches, nseg))\n\n    conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))(\n        axial_resistivity, radiuses, lengths, self.nseg\n    )\n    coupling_conds_fwd = conds[0]\n    coupling_conds_bwd = conds[1]\n    summed_coupling_conds = conds[2]\n\n    # The conductance from the children to the branch point.\n    branchpoint_conds_children = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The conductance from the parents to the branch point.\n    branchpoint_conds_parents = vmap(\n        compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)\n    )(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    # Weights with which the compartments influence their nearby node.\n    # The impact of the children on the branch point.\n    branchpoint_weights_children = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.child_inds, 0],\n        axial_resistivity[self.child_inds, 0],\n        lengths[self.child_inds, 0],\n    )\n    # The impact of parents on the branch point.\n    branchpoint_weights_parents = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n        radiuses[self.par_inds, -1],\n        axial_resistivity[self.par_inds, -1],\n        lengths[self.par_inds, -1],\n    )\n\n    summed_coupling_conds = Cell.update_summed_coupling_conds(\n        summed_coupling_conds,\n        self.child_inds,\n        self.par_inds,\n        branchpoint_conds_children,\n        branchpoint_conds_parents,\n    )\n\n    cond_params = {\n        \"branch_uppers\": coupling_conds_bwd,\n        \"branch_lowers\": coupling_conds_fwd,\n        \"branch_diags\": summed_coupling_conds,\n        \"branchpoint_conds_children\": branchpoint_conds_children,\n        \"branchpoint_conds_parents\": branchpoint_conds_parents,\n        \"branchpoint_weights_children\": branchpoint_weights_children,\n        \"branchpoint_weights_parents\": branchpoint_weights_parents,\n    }\n    return cond_params\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.init_syns","title":"init_syns()","text":"

    Initialize synapses.

    Source code in jaxley/modules/network.py
    def init_syns(self):\n    \"\"\"Initialize synapses.\"\"\"\n    self.synapses = []\n\n    # TODO(@michaeldeistler): should we also track this for channels?\n    self.synapse_names = []\n    self.synapse_param_names = []\n    self.synapse_state_names = []\n\n    self.initialized_syns = True\n
    "},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, col='k', synapse_col='b', dims=(0, 1), type='line', layers=None, morph_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, networkx_options={}, layer_kwargs={})","text":"

    Visualize the module.

    Parameters:

    Name Type Description Default detail str

    Either of [point, full]. point visualizes every neuron in the network as a dot (and it uses networkx to obtain cell positions). full plots the full morphology of every neuron. It requires that compute_xyz() has been run and allows for indivual neurons to be moved with .move().

    'full' col str

    The color in which cells are plotted. Only takes effect if detail='full'.

    'k' type str

    Either line or scatter. Only takes effect if detail='full'.

    'line' synapse_col str

    The color in which synapses are plotted. Only takes effect if detail='full'.

    'b' dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) layers Optional[List]

    Allows to plot the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron.

    None morph_plot_kwargs Dict

    Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

    {} synapse_plot_kwargs Dict

    Keyword arguments passed to the plotting function for syanpses. Only takes effect for detail='full'.

    {} synapse_scatter_kwargs Dict

    Keyword arguments passed to the scatter function for the end point of synapses. Only takes effect for detail='full'.

    {} networkx_options Dict

    Options passed to networkx.draw(). Only takes effect if detail='point'.

    {} layer_kwargs Dict

    Only used if layers is specified and if detail='full'. Can have the following entries: within_layer_offset (float), between_layer_offset (float), vertical_layers (bool).

    {} Source code in jaxley/modules/network.py
    def vis(\n    self,\n    detail: str = \"full\",\n    ax: Optional[Axes] = None,\n    col: str = \"k\",\n    synapse_col: str = \"b\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    layers: Optional[List] = None,\n    morph_plot_kwargs: Dict = {},\n    synapse_plot_kwargs: Dict = {},\n    synapse_scatter_kwargs: Dict = {},\n    networkx_options: Dict = {},\n    layer_kwargs: Dict = {},\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        detail: Either of [point, full]. `point` visualizes every neuron in the\n            network as a dot (and it uses `networkx` to obtain cell positions).\n            `full` plots the full morphology of every neuron. It requires that\n            `compute_xyz()` has been run and allows for indivual neurons to be\n            moved with `.move()`.\n        col: The color in which cells are plotted. Only takes effect if\n            `detail='full'`.\n        type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n        synapse_col: The color in which synapses are plotted. Only takes effect if\n            `detail='full'`.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        layers: Allows to plot the network in layers. Should provide the number of\n            neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n            neurons, 10 hidden layer neurons, and 1 output neuron.\n        morph_plot_kwargs: Keyword arguments passed to the plotting function for\n            cell morphologies. Only takes effect for `detail='full'`.\n        synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n            syanpses. Only takes effect for `detail='full'`.\n        synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n            for the end point of synapses. Only takes effect for `detail='full'`.\n        networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n            `detail='point'`.\n        layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n            Can have the following entries: `within_layer_offset` (float),\n            `between_layer_offset` (float), `vertical_layers` (bool).\n    \"\"\"\n    if detail == \"point\":\n        graph = self._build_graph(layers)\n\n        if layers is not None:\n            pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n            nx.draw(graph, pos, with_labels=True, **networkx_options)\n        else:\n            nx.draw(graph, with_labels=True, **networkx_options)\n    elif detail == \"full\":\n        if layers is not None:\n            # Assemble cells in the network into layers.\n            global_counter = 0\n            layers_config = {\n                \"within_layer_offset\": 500.0,\n                \"between_layer_offset\": 1500.0,\n                \"vertical_layers\": False,\n            }\n            layers_config.update(layer_kwargs)\n            for layer_ind, num_in_layer in enumerate(layers):\n                for ind_within_layer in range(num_in_layer):\n                    if layers_config[\"vertical_layers\"]:\n                        x_offset = (\n                            ind_within_layer - (num_in_layer - 1) / 2\n                        ) * layers_config[\"within_layer_offset\"]\n                        y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n                            \"between_layer_offset\"\n                        ]\n                    else:\n                        x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n                        y_offset = (\n                            ind_within_layer - (num_in_layer - 1) / 2\n                        ) * layers_config[\"within_layer_offset\"]\n\n                    self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n                    global_counter += 1\n        ax = self._vis(\n            dims=dims,\n            col=col,\n            ax=ax,\n            type=type,\n            view=self.nodes,\n            morph_plot_kwargs=morph_plot_kwargs,\n        )\n\n        pre_locs = self.edges[\"pre_locs\"].to_numpy()\n        post_locs = self.edges[\"post_locs\"].to_numpy()\n        pre_branch = self.edges[\"global_pre_branch_index\"].to_numpy()\n        post_branch = self.edges[\"global_post_branch_index\"].to_numpy()\n\n        dims_np = np.asarray(dims)\n\n        for pre_loc, post_loc, pre_b, post_b in zip(\n            pre_locs, post_locs, pre_branch, post_branch\n        ):\n            pre_coord = self.xyzr[pre_b]\n            if len(pre_coord) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(pre_coord) - 1) * pre_loc)\n                pre_coord = pre_coord[middle_ind]\n\n            post_coord = self.xyzr[post_b]\n            if len(post_coord) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                post_coord = (\n                    post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n                )\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(post_coord) - 1) * post_loc)\n                post_coord = post_coord[middle_ind]\n\n            coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n            ax.plot(\n                coords[0],\n                coords[1],\n                c=synapse_col,\n                **synapse_plot_kwargs,\n            )\n            ax.scatter(\n                post_coord[dims_np[0]],\n                post_coord[dims_np[1]],\n                c=synapse_col,\n                **synapse_scatter_kwargs,\n            )\n    else:\n        raise ValueError(\"detail must be in {full, point}.\")\n\n    return ax\n
    "},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer","text":"

    optax wrapper which allows different argument values for different params.

    Source code in jaxley/optimize/optimizer.py
    class TypeOptimizer:\n    \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n    def __init__(\n        self,\n        optimizer: Callable,\n        optimizer_args: Dict[str, Any],\n        opt_params: List[Dict[str, jnp.ndarray]],\n    ):\n        \"\"\"Create the optimizers.\n\n        This requires access to `opt_params` in order to know how many optimizers\n        should be created. It creates `len(opt_params)` optimizers.\n\n        Example usage:\n        ```\n        lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        ```\n        optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n        optimizer = TypeOptimizer(\n            lambda args: optax.sgd(args[0], momentum=args[1]),\n            optimizer_args,\n            opt_params\n        )\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        Args:\n            optimizer: A Callable that takes the learning rate and returns the\n                `optax.optimizer` which should be used.\n            optimizer_args: The arguments for different kinds of parameters.\n                Each item of the dictionary will be passed to the `Callable` passed to\n                `optimizer`.\n            opt_params: The parameters to be optimized. The exact values are not used,\n                only the number of elements in the list and the key of each dict.\n        \"\"\"\n        self.base_optimizer = optimizer\n\n        self.optimizers = []\n        for params in opt_params:\n            names = list(params.keys())\n            assert len(names) == 1, \"Multiple parameters were added at once.\"\n            name = names[0]\n            optimizer = self.base_optimizer(optimizer_args[name])\n            self.optimizers.append({name: optimizer})\n\n    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n        \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n        opt_states = []\n        for params, optimizer in zip(opt_params, self.optimizers):\n            name = list(optimizer.keys())[0]\n            opt_state = optimizer[name].init(params)\n            opt_states.append(opt_state)\n        return opt_states\n\n    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n        \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n        all_updates = []\n        new_opt_states = []\n        for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n            name = list(opt.keys())[0]\n            updates, new_opt_state = opt[name].update(grad, state)\n            all_updates.append(updates)\n            new_opt_states.append(new_opt_state)\n        return all_updates, new_opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)","text":"

    Create the optimizers.

    This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

    Example usage:

    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n

    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n    lambda args: optax.sgd(args[0], momentum=args[1]),\n    optimizer_args,\n    opt_params\n)\nopt_state = optimizer.init(opt_params)\n

    Parameters:

    Name Type Description Default optimizer Callable

    A Callable that takes the learning rate and returns the optax.optimizer which should be used.

    required optimizer_args Dict[str, Any]

    The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable passed to optimizer.

    required opt_params List[Dict[str, ndarray]]

    The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.

    required Source code in jaxley/optimize/optimizer.py
    def __init__(\n    self,\n    optimizer: Callable,\n    optimizer_args: Dict[str, Any],\n    opt_params: List[Dict[str, jnp.ndarray]],\n):\n    \"\"\"Create the optimizers.\n\n    This requires access to `opt_params` in order to know how many optimizers\n    should be created. It creates `len(opt_params)` optimizers.\n\n    Example usage:\n    ```\n    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    ```\n    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n    optimizer = TypeOptimizer(\n        lambda args: optax.sgd(args[0], momentum=args[1]),\n        optimizer_args,\n        opt_params\n    )\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    Args:\n        optimizer: A Callable that takes the learning rate and returns the\n            `optax.optimizer` which should be used.\n        optimizer_args: The arguments for different kinds of parameters.\n            Each item of the dictionary will be passed to the `Callable` passed to\n            `optimizer`.\n        opt_params: The parameters to be optimized. The exact values are not used,\n            only the number of elements in the list and the key of each dict.\n    \"\"\"\n    self.base_optimizer = optimizer\n\n    self.optimizers = []\n    for params in opt_params:\n        names = list(params.keys())\n        assert len(names) == 1, \"Multiple parameters were added at once.\"\n        name = names[0]\n        optimizer = self.base_optimizer(optimizer_args[name])\n        self.optimizers.append({name: optimizer})\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)","text":"

    Initialize the optimizers. Equivalent to optax.optimizers.init().

    Source code in jaxley/optimize/optimizer.py
    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n    \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n    opt_states = []\n    for params, optimizer in zip(opt_params, self.optimizers):\n        name = list(optimizer.keys())[0]\n        opt_state = optimizer[name].init(params)\n        opt_states.append(opt_state)\n    return opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)","text":"

    Update the optimizers. Equivalent to optax.optimizers.update().

    Source code in jaxley/optimize/optimizer.py
    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n    \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n    all_updates = []\n    new_opt_states = []\n    for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n        name = list(opt.keys())[0]\n        updates, new_opt_state = opt[name].update(grad, state)\n        all_updates.append(updates)\n        new_opt_states.append(new_opt_state)\n    return all_updates, new_opt_states\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform","text":"

    Parameter transformation utility.

    This class is used to transform parameters from an unconstrained space to a constrained space and back. If the range is bounded both from above and below, we use the sigmoid function to transform the parameters. If the range is only bounded from below or above, we use softplus.

    Attributes:

    Name Type Description lowers

    A dictionary of lower bounds for each parameter (None for no bound).

    uppers

    A dictionary of upper bounds for each parameter (None for no bound).

    Source code in jaxley/optimize/transforms.py
    class ParamTransform:\n    \"\"\"Parameter transformation utility.\n\n    This class is used to transform parameters from an unconstrained space to a constrained space\n    and back. If the range is bounded both from above and below, we use the sigmoid function to\n    transform the parameters. If the range is only bounded from below or above, we use softplus.\n\n    Attributes:\n        lowers: A dictionary of lower bounds for each parameter (None for no bound).\n        uppers: A dictionary of upper bounds for each parameter (None for no bound).\n\n    \"\"\"\n\n    def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):\n        \"\"\"Initialize the bounds.\n\n        Args:\n            lowers: A dictionary of lower bounds for each parameter (None for no bound).\n            uppers: A dictionary of upper bounds for each parameter (None for no bound).\n        \"\"\"\n\n        self.lowers = lowers\n        self.uppers = uppers\n\n    def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:\n        \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n        Args:\n            params: A list of dictionaries with unconstrained parameters.\n\n        Returns:\n            A list of dictionaries with transformed parameters.\n\n        \"\"\"\n\n        tf_params = []\n        for param in params:\n            key = list(param.keys())[0]\n\n            # If constrained from below and above, use sigmoid\n            if self.lowers[key] is not None and self.uppers[key] is not None:\n                tf = (\n                    sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])\n                    + self.lowers[key]\n                )\n                tf_params.append({key: tf})\n\n            # If constrained from below, use softplus\n            elif self.lowers[key] is not None:\n                tf = softplus(param[key]) + self.lowers[key]\n                tf_params.append({key: tf})\n\n            # If constrained from above, use negative softplus\n            elif self.uppers[key] is not None:\n                tf = -softplus(-param[key]) + self.uppers[key]\n                tf_params.append({key: tf})\n\n            # Else just pass through\n            else:\n                tf_params.append({key: param[key]})\n\n        return tf_params\n\n    def inverse(self, params: jnp.ndarray) -> jnp.ndarray:\n        \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n        Args:\n            params: A list of dictionaries with transformed parameters.\n\n        Returns:\n            A list of dictionaries with unconstrained parameters.\n        \"\"\"\n\n        tf_params = []\n        for param in params:\n            key = list(param.keys())[0]\n\n            # If constrained from below and above, use expit\n            if self.lowers[key] is not None and self.uppers[key] is not None:\n                tf = expit(\n                    (param[key] - self.lowers[key])\n                    / (self.uppers[key] - self.lowers[key])\n                )\n                tf_params.append({key: tf})\n\n            # If constrained from below, use inv_softplus\n            elif self.lowers[key] is not None:\n                tf = inv_softplus(param[key] - self.lowers[key])\n                tf_params.append({key: tf})\n\n            # If constrained from above, use negative inv_softplus\n            elif self.uppers[key] is not None:\n                tf = -inv_softplus(-(param[key] - self.uppers[key]))\n                tf_params.append({key: tf})\n\n            # else just pass through\n            else:\n                tf_params.append({key: param[key]})\n\n        return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(lowers, uppers)","text":"

    Initialize the bounds.

    Parameters:

    Name Type Description Default lowers Dict[str, float]

    A dictionary of lower bounds for each parameter (None for no bound).

    required uppers Dict[str, float]

    A dictionary of upper bounds for each parameter (None for no bound).

    required Source code in jaxley/optimize/transforms.py
    def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):\n    \"\"\"Initialize the bounds.\n\n    Args:\n        lowers: A dictionary of lower bounds for each parameter (None for no bound).\n        uppers: A dictionary of upper bounds for each parameter (None for no bound).\n    \"\"\"\n\n    self.lowers = lowers\n    self.uppers = uppers\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)","text":"

    Pushes unconstrained parameters through a tf such that they fit the interval.

    Parameters:

    Name Type Description Default params List[Dict[str, ndarray]]

    A list of dictionaries with unconstrained parameters.

    required

    Returns:

    Type Description ndarray

    A list of dictionaries with transformed parameters.

    Source code in jaxley/optimize/transforms.py
    def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:\n    \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n    Args:\n        params: A list of dictionaries with unconstrained parameters.\n\n    Returns:\n        A list of dictionaries with transformed parameters.\n\n    \"\"\"\n\n    tf_params = []\n    for param in params:\n        key = list(param.keys())[0]\n\n        # If constrained from below and above, use sigmoid\n        if self.lowers[key] is not None and self.uppers[key] is not None:\n            tf = (\n                sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])\n                + self.lowers[key]\n            )\n            tf_params.append({key: tf})\n\n        # If constrained from below, use softplus\n        elif self.lowers[key] is not None:\n            tf = softplus(param[key]) + self.lowers[key]\n            tf_params.append({key: tf})\n\n        # If constrained from above, use negative softplus\n        elif self.uppers[key] is not None:\n            tf = -softplus(-param[key]) + self.uppers[key]\n            tf_params.append({key: tf})\n\n        # Else just pass through\n        else:\n            tf_params.append({key: param[key]})\n\n    return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)","text":"

    Takes parameters from within the interval and makes them unconstrained.

    Parameters:

    Name Type Description Default params ndarray

    A list of dictionaries with transformed parameters.

    required

    Returns:

    Type Description ndarray

    A list of dictionaries with unconstrained parameters.

    Source code in jaxley/optimize/transforms.py
    def inverse(self, params: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n    Args:\n        params: A list of dictionaries with transformed parameters.\n\n    Returns:\n        A list of dictionaries with unconstrained parameters.\n    \"\"\"\n\n    tf_params = []\n    for param in params:\n        key = list(param.keys())[0]\n\n        # If constrained from below and above, use expit\n        if self.lowers[key] is not None and self.uppers[key] is not None:\n            tf = expit(\n                (param[key] - self.lowers[key])\n                / (self.uppers[key] - self.lowers[key])\n            )\n            tf_params.append({key: tf})\n\n        # If constrained from below, use inv_softplus\n        elif self.lowers[key] is not None:\n            tf = inv_softplus(param[key] - self.lowers[key])\n            tf_params.append({key: tf})\n\n        # If constrained from above, use negative inv_softplus\n        elif self.uppers[key] is not None:\n            tf = -inv_softplus(-(param[key] - self.uppers[key]))\n            tf_params.append({key: tf})\n\n        # else just pass through\n        else:\n            tf_params.append({key: param[key]})\n\n    return tf_params\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.expit","title":"expit(x)","text":"

    Inverse sigmoid (expit)

    Source code in jaxley/optimize/transforms.py
    def expit(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Inverse sigmoid (expit)\"\"\"\n    return -jnp.log(1 / x - 1)\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.inv_softplus","title":"inv_softplus(x)","text":"

    Inverse softplus.

    Source code in jaxley/optimize/transforms.py
    def inv_softplus(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Inverse softplus.\"\"\"\n    return jnp.log(jnp.exp(x) - 1)\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.sigmoid","title":"sigmoid(x)","text":"

    Sigmoid.

    Source code in jaxley/optimize/transforms.py
    def sigmoid(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Sigmoid.\"\"\"\n    return 1 / (1 + save_exp(-x))\n
    "},{"location":"reference/optimize/#jaxley.optimize.transforms.softplus","title":"softplus(x)","text":"

    Softplus.

    Source code in jaxley/optimize/transforms.py
    def softplus(x: jnp.ndarray) -> jnp.ndarray:\n    \"\"\"Softplus.\"\"\"\n    return jnp.log(1 + jnp.exp(x))\n
    "},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)","text":"

    Return all children indices of every branch.

    Example:

    parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n

    Source code in jaxley/utils/cell_utils.py
    def compute_children_indices(parents) -> List[jnp.ndarray]:\n    \"\"\"Return all children indices of every branch.\n\n    Example:\n    ```\n    parents = [-1, 0, 0]\n    compute_children_indices(parents) -> [[1, 2], [], []]\n    ```\n    \"\"\"\n    num_branches = len(parents)\n    child_indices = []\n    for b in range(num_branches):\n        child_indices.append(np.where(parents == b)[0])\n    return child_indices\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)","text":"

    Return the coupling conductance between two compartments.

    Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

    radius: um r_a: ohm cm length_single_compartment: um coupling_conds: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2

    Source code in jaxley/utils/cell_utils.py
    def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n    \"\"\"Return the coupling conductance between two compartments.\n\n    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n    `radius`: um\n    `r_a`: ohm cm\n    `length_single_compartment`: um\n    `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n    \"\"\"\n    # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n    return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)","text":"

    Return the coupling conductance between one compartment and a comp with l=0.

    From https://en.wikipedia.org/wiki/Compartmental_neuron_models

    If one compartment has l=0.0 then the equations simplify.

    R_long = \\sum_i r_a * L_i/2 / crosssection_i

    with crosssection = pi * r**2

    For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection

    Then, g_long = crosssection * 2 / L / r_a

    Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2

    Source code in jaxley/utils/cell_utils.py
    def compute_coupling_cond_branchpoint(rad, r_a, l):\n    r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n    From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n    If one compartment has l=0.0 then the equations simplify.\n\n    R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n    with crosssection = pi * r**2\n\n    For a single compartment with L>0, this turns into:\n    R_long = r_a * L/2 / crosssection\n\n    Then, g_long = crosssection * 2 / L / r_a\n\n    Then, the effective conductance is g_long / zylinder_area. So:\n    g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n    g = r / r_a / L**2\n    \"\"\"\n    return rad / r_a / l**2 * 10**7  # Convert (S / cm / um) -> (mS / cm^2)\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)","text":"

    Compute the weight with which a compartment influences its node.

    In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0

    Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a

    This equation can be multiplied by any constant.

    Source code in jaxley/utils/cell_utils.py
    def compute_impact_on_node(rad, r_a, l):\n    r\"\"\"Compute the weight with which a compartment influences its node.\n\n    In order to satisfy Kirchhoffs current law, the current at a branch point must be\n    proportional to the crosssection of the compartment. We only require proportionality\n    here because the branch point equation reads:\n    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n    Because R_long = r_a * L/2 / crosssection, we get\n    g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n    This equation can be multiplied by any constant.\"\"\"\n    return rad**2 / r_a / l\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)","text":"

    Return (row, col) to build the sparse matrix defining the voltage eqs.

    This is run at init, not during runtime.

    Source code in jaxley/utils/cell_utils.py
    def compute_morphology_indices_in_levels(\n    num_branchpoints,\n    child_belongs_to_branchpoint,\n    par_inds,\n    child_inds,\n):\n    \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n    This is run at `init`, not during runtime.\n    \"\"\"\n    branchpoint_inds_parents = jnp.arange(num_branchpoints)\n    branchpoint_inds_children = child_belongs_to_branchpoint\n    branch_inds_parents = par_inds\n    branch_inds_children = child_inds\n\n    children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n    parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n    return {\"children\": children.T, \"parents\": parents.T}\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)","text":"

    Convert current point process (nA) to distributed current (uA/cm2).

    This function gets called for synapses and for external stimuli.

    Parameters:

    Name Type Description Default current ndarray

    Current in nA.

    required radius ndarray

    Compartment radius in um.

    required length ndarray

    Compartment length in um.

    required Return

    Current in uA/cm2.

    Source code in jaxley/utils/cell_utils.py
    def convert_point_process_to_distributed(\n    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n    \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n    This function gets called for synapses and for external stimuli.\n\n    Args:\n        current: Current in `nA`.\n        radius: Compartment radius in `um`.\n        length: Compartment length in `um`.\n\n    Return:\n        Current in `uA/cm2`.\n    \"\"\"\n    area = 2 * pi * radius * length\n    current /= area  # nA / um^2\n    return current * 100_000  # Convert (nA / um^2) to (uA / cm^2)\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, nseg_per_branch)","text":"

    Generates segments where some property is the same in each segment.

    Parameters:

    Name Type Description Default branch_property list

    List of values of the property in each branch. Should have len(branch_property) == num_branches.

    required Source code in jaxley/utils/cell_utils.py
    def equal_segments(branch_property: list, nseg_per_branch: int):\n    \"\"\"Generates segments where some property is the same in each segment.\n\n    Args:\n        branch_property: List of values of the property in each branch. Should have\n            `len(branch_property) == num_branches`.\n    \"\"\"\n    assert isinstance(branch_property, list), \"branch_property must be a list.\"\n    return jnp.asarray([branch_property] * nseg_per_branch).T\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, nseg_per_branch, num_branches)","text":"

    Number of neighbours of each compartment.

    Source code in jaxley/utils/cell_utils.py
    def get_num_neighbours(\n    num_children: jnp.ndarray,\n    nseg_per_branch: int,\n    num_branches: int,\n):\n    \"\"\"\n    Number of neighbours of each compartment.\n    \"\"\"\n    num_neighbours = 2 * jnp.ones((num_branches * nseg_per_branch))\n    num_neighbours = num_neighbours.at[nseg_per_branch - 1].set(1.0)\n    num_neighbours = num_neighbours.at[jnp.arange(num_branches) * nseg_per_branch].set(\n        num_children + 1.0\n    )\n    return num_neighbours\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)","text":"

    Group values by whether they have the same integer and sum values within group.

    This is used to construct the last diagonals at the branch points.

    Written by ChatGPT.

    Source code in jaxley/utils/cell_utils.py
    def group_and_sum(\n    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n    \"\"\"Group values by whether they have the same integer and sum values within group.\n\n    This is used to construct the last diagonals at the branch points.\n\n    Written by ChatGPT.\n    \"\"\"\n    # Initialize an array to hold the sum of each group\n    group_sums = jnp.zeros(num_branchpoints)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n    return group_sums\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.index_of_loc","title":"index_of_loc(branch_ind, loc, nseg_per_branch)","text":"

    Returns the index of a segment given a loc in [0, 1] and the index of a branch.

    This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

    Parameters:

    Name Type Description Default branch_ind int

    Index of the branch.

    required loc float

    Location (in [0, 1]) along that branch.

    required nseg_per_branch int

    Number of segments of each branch.

    required

    Returns:

    Type Description int

    The index of the compartment within the entire cell.

    Source code in jaxley/utils/cell_utils.py
    def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:\n    \"\"\"Returns the index of a segment given a loc in [0, 1] and the index of a branch.\n\n    This is used because we specify locations such as synapses as a value between 0 and\n    1. We have to convert this onto a discrete segment here.\n\n    Args:\n        branch_ind: Index of the branch.\n        loc: Location (in [0, 1]) along that branch.\n        nseg_per_branch: Number of segments of each branch.\n\n    Returns:\n        The index of the compartment within the entire cell.\n    \"\"\"\n    nseg = nseg_per_branch  # only for convenience.\n    possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)\n    ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n    return branch_ind * nseg + ind_along_branch\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyz","title":"interpolate_xyz(loc, coords)","text":"

    Perform a linear interpolation between xyz-coordinates.

    Parameters:

    Name Type Description Default loc float

    The location in [0,1] along the branch.

    required coords ndarray

    Array containing the reconstructed xyzr points of the branch.

    required Return

    Interpolated xyz coordinate at loc, shape `(3,).

    Source code in jaxley/utils/cell_utils.py
    def interpolate_xyz(loc: float, coords: np.ndarray):\n    \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n    Args:\n        loc: The location in [0,1] along the branch.\n        coords: Array containing the reconstructed xyzr points of the branch.\n\n    Return:\n        Interpolated xyz coordinate at `loc`, shape `(3,).\n    \"\"\"\n    return vmap(lambda x: jnp.interp(loc, jnp.linspace(0, 1, len(x)), x), in_axes=(1,))(\n        coords[:, :3]\n    )\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, nseg_per_branch)","text":"

    Generates segments where some property is linearly interpolated.

    Parameters:

    Name Type Description Default initial_val float

    The value at the tip of the soma.

    required endpoint_vals list

    The value at the endpoints of each branch.

    required Source code in jaxley/utils/cell_utils.py
    def linear_segments(\n    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, nseg_per_branch: int\n):\n    \"\"\"Generates segments where some property is linearly interpolated.\n\n    Args:\n        initial_val: The value at the tip of the soma.\n        endpoint_vals: The value at the endpoints of each branch.\n    \"\"\"\n    branch_property = endpoint_vals + [initial_val]\n    num_branches = len(parents)\n    # Compute radiuses by linear interpolation.\n    endpoint_radiuses = jnp.asarray(branch_property)\n\n    def compute_rad(branch_ind, loc):\n        start = endpoint_radiuses[parents[branch_ind]]\n        end = endpoint_radiuses[branch_ind]\n        return (end - start) * loc + start\n\n    branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), nseg_per_branch)\n    locs_of_each_comp = jnp.linspace(1, 0, nseg_per_branch).repeat(num_branches)\n    rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n    return jnp.reshape(rad_of_each_comp, (nseg_per_branch, num_branches)).T\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, nseg)","text":"

    Return location corresponding to index.

    Source code in jaxley/utils/cell_utils.py
    def loc_of_index(global_comp_index, nseg):\n    \"\"\"Return location corresponding to index.\"\"\"\n    index = global_comp_index % nseg\n    possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)\n    return possible_locs[index]\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)","text":"

    Build full list of which branches are solved in which iteration.

    From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.

    Parameters:

    Name Type Description Default cumsum_num_branches List[int]

    cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30].

    required arrs List[List[ndarray]]

    A list of a list of arrays that should be merged.

    required exclude_first bool

    If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates \u201cno parent\u201d) entry should not be changed.

    True

    Returns:

    Type Description ndarray

    A list of arrays which contain the branch indices that are computed at each

    ndarray

    level (i.e., iteration).

    Source code in jaxley/utils/cell_utils.py
    def merge_cells(\n    cumsum_num_branches: List[int],\n    cumsum_num_branchpoints: List[int],\n    arrs: List[List[jnp.ndarray]],\n    exclude_first: bool = True,\n) -> jnp.ndarray:\n    \"\"\"\n    Build full list of which branches are solved in which iteration.\n\n    From the branching pattern of single cells, this \"merges\" them into a single\n    ordering of branches.\n\n    Args:\n        cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n            10, 15, and 5 branches respectively, this will should be a list containing\n            `[0, 10, 25, 30]`.\n        arrs: A list of a list of arrays that should be merged.\n        exclude_first: If `True`, the first element of each list in `arrs` will remain\n            unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n            be changed.\n\n    Returns:\n        A list of arrays which contain the branch indices that are computed at each\n        level (i.e., iteration).\n    \"\"\"\n    ps = []\n    for i, att in enumerate(arrs):\n        p = att\n        if exclude_first:\n            raise NotImplementedError\n            p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n        else:\n            p = [\n                p_in_level\n                + jnp.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n                for p_in_level in p\n            ]\n        ps.append(p)\n\n    max_len = max([len(att) for att in arrs])\n    combined_parents_in_level = []\n    for i in range(max_len):\n        current_ps = []\n        for p in ps:\n            if len(p) > i:\n                current_ps.append(p[i])\n        combined_parents_in_level.append(jnp.concatenate(current_ps))\n\n    return combined_parents_in_level\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)","text":"

    Make outputs get_parameters() conform with outputs of .data_set().

    make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params). Therefore, in jx.integrate, we run the function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

    Source code in jaxley/utils/cell_utils.py
    def params_to_pstate(\n    params: List[Dict[str, jnp.ndarray]],\n    indices_set_by_trainables: List[jnp.ndarray],\n):\n    \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n    `make_trainable()` followed by `params=get_parameters()` does not return indices\n    because these indices would also be differentiated by `jax.grad` (as soon as\n    the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n    we run the function to add indices to the dict. The outputs of `params_to_pstate`\n    are of the same shape as the outputs of `.data_set()`.\"\"\"\n    return [\n        {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n        for p, i in zip(params, indices_set_by_trainables)\n    ]\n
    "},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)","text":"

    Maps an array of integers to an array of consecutive integers.

    E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

    Source code in jaxley/utils/cell_utils.py
    def remap_to_consecutive(arr):\n    \"\"\"Maps an array of integers to an array of consecutive integers.\n\n    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n    \"\"\"\n    _, inverse_indices = jnp.unique(arr, return_inverse=True)\n    return inverse_indices\n
    "},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})","text":"

    Plot morphology.

    Parameters:

    Name Type Description Default dims Tuple[int]

    Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

    (0, 1) type str

    Either line or scatter.

    'line' col str

    The color for all branches.

    'k' Source code in jaxley/utils/plot_utils.py
    def plot_morph(\n    xyzr,\n    dims: Tuple[int] = (0, 1),\n    col: str = \"k\",\n    ax: Optional[Axes] = None,\n    type: str = \"line\",\n    morph_plot_kwargs: Dict = {},\n):\n    \"\"\"Plot morphology.\n\n    Args:\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        type: Either `line` or `scatter`.\n        col: The color for all branches.\n    \"\"\"\n\n    if ax is None:\n        _, ax = plt.subplots(1, 1, figsize=(3, 3))\n\n    for coords_of_branch in xyzr:\n        x1, x2 = coords_of_branch[:, dims].T\n\n        if \"line\" in type.lower():\n            _ = ax.plot(x1, x2, color=col, **morph_plot_kwargs)\n        elif \"scatter\" in type.lower():\n            _ = ax.scatter(x1, x2, color=col, **morph_plot_kwargs)\n        else:\n            raise NotImplementedError\n\n    return ax\n
    "},{"location":"reference/utils/#jaxley.utils.swc.swc_to_jaxley","title":"swc_to_jaxley(fname, max_branch_len=100.0, sort=True, num_lines=None)","text":"

    Read an SWC file and bring morphology into jaxley compatible formats.

    Parameters:

    Name Type Description Default fname str

    Path to swc file.

    required max_branch_len float

    Maximal length of one branch. If a branch exceeds this length, it is split into equal parts such that each subbranch is below max_branch_len.

    100.0 num_lines Optional[int]

    Number of lines of the SWC file to read.

    None Source code in jaxley/utils/swc.py
    def swc_to_jaxley(\n    fname: str,\n    max_branch_len: float = 100.0,\n    sort: bool = True,\n    num_lines: Optional[int] = None,\n) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]:\n    \"\"\"Read an SWC file and bring morphology into `jaxley` compatible formats.\n\n    Args:\n        fname: Path to swc file.\n        max_branch_len: Maximal length of one branch. If a branch exceeds this length,\n            it is split into equal parts such that each subbranch is below\n            `max_branch_len`.\n        num_lines: Number of lines of the SWC file to read.\n    \"\"\"\n    content = np.loadtxt(fname)[:num_lines]\n    types = content[:, 1]\n    is_single_point_soma = types[0] == 1 and types[1] != 1\n\n    if is_single_point_soma:\n        # Warn here, but the conversion of the length happens in `_compute_pathlengths`.\n        warn(\n            \"Found a soma which consists of a single traced point. `Jaxley` \"\n            \"interprets this soma as a spherical compartment with radius \"\n            \"specified in the SWC file, i.e. with surface area 4*pi*r*r.\"\n        )\n    sorted_branches, types = _split_into_branches_and_sort(\n        content,\n        max_branch_len=max_branch_len,\n        is_single_point_soma=is_single_point_soma,\n        sort=sort,\n    )\n\n    parents = _build_parents(sorted_branches)\n    each_length = _compute_pathlengths(\n        sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma\n    )\n    pathlengths = [np.sum(length_traced) for length_traced in each_length]\n    for i, pathlen in enumerate(pathlengths):\n        if pathlen == 0.0:\n            warn(\"Found a segment with length 0. Clipping it to 1.0\")\n            pathlengths[i] = 1.0\n    radius_fns = _radius_generating_fns(\n        sorted_branches, content[:, 5], each_length, parents, types\n    )\n\n    if np.sum(np.asarray(parents) == -1) > 1.0:\n        parents = np.asarray([-1] + parents)\n        parents[1:] += 1\n        parents = parents.tolist()\n        pathlengths = [0.1] + pathlengths\n        radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns\n        sorted_branches = [[0]] + sorted_branches\n\n        # Type of padded section is assumed to be of `custom` type:\n        # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html\n        types = [5.0] + types\n\n    all_coords_of_branches = []\n    for i, branch in enumerate(sorted_branches):\n        # Remove 1 because `content` is an array that is indexed from 0.\n        branch = np.asarray(branch) - 1\n\n        # Deal with additional branch that might have been added above in the lines\n        # `if np.sum(np.asarray(parents) == -1) > 1.0:`\n        branch[branch < 0] = 0\n\n        # Get traced coordinates of the branch.\n        coords_of_branch = content[branch, 2:6]\n        all_coords_of_branches.append(coords_of_branch)\n\n    return parents, pathlengths, radius_fns, types, all_coords_of_branches\n
    "},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)","text":"

    A version of lax.scan that supports recursive gradient checkpointing.

    Code taken from: https://github.com/google/jax/issues/2139

    The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

    The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

    nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

    Parameters:

    Name Type Description Default f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

    function to scan over.

    required init Carry

    initial value.

    required xs Dict[str, ndarray]

    scanned over values.

    required length Optional[int]

    leading length of all dimensions

    None nested_lengths Sequence[int]

    required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

    required scan_fn

    function matching the API of lax.scan

    scan checkpoint_fn Callable[[Func], Func]

    function matching the API of jax.checkpoint.

    checkpoint Source code in jaxley/utils/jax_utils.py
    def nested_checkpoint_scan(\n    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n    init: Carry,\n    xs: Dict[str, jnp.ndarray],\n    length: Optional[int] = None,\n    *,\n    nested_lengths: Sequence[int],\n    scan_fn=jax.lax.scan,\n    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n    \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n    Code taken from: https://github.com/google/jax/issues/2139\n\n    The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n    the required `nested_lengths` argument.\n\n    The key feature of `nested_checkpoint_scan` is that gradient calculations\n    require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n    scans, which it achieves by re-evaluating the forward pass\n    `len(nested_lengths) - 1` times.\n\n    `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n    single element.\n\n    Args:\n        f: function to scan over.\n        init: initial value.\n        xs: scanned over values.\n        length: leading length of all dimensions\n        nested_lengths: required list of lengths to scan over for each level of\n            checkpointing. The product of nested_lengths must match length (if\n            provided) and the size of the leading axis for all arrays in ``xs``.\n        scan_fn: function matching the API of lax.scan\n        checkpoint_fn: function matching the API of jax.checkpoint.\n    \"\"\"\n    if length is not None and length != math.prod(nested_lengths):\n        raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n    def nested_reshape(x):\n        x = jnp.asarray(x)\n        new_shape = tuple(nested_lengths) + x.shape[1:]\n        return x.reshape(new_shape)\n\n    sub_xs = jax.tree_map(nested_reshape, xs)\n    return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
    "},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)","text":"

    Compute current at the post synapse.

    All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.

    Source code in jaxley/utils/syn_utils.py
    def gather_synapes(\n    number_of_compartments: jnp.ndarray,\n    post_syn_comp_inds: np.ndarray,\n    current_each_synapse_voltage_term: jnp.ndarray,\n    current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n    \"\"\"Compute current at the post synapse.\n\n    All this does it that it sums the synaptic currents that come into a particular\n    compartment. It returns an array of as many elements as there are compartments.\n    \"\"\"\n    incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n    incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n    dnums = ScatterDimensionNumbers(\n        update_window_dims=(),\n        inserted_window_dims=(0,),\n        scatter_dims_to_operand_dims=(0,),\n    )\n    incoming_currents_voltages = scatter_add(\n        incoming_currents_voltages,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_voltage_term,\n        dnums,\n    )\n    incoming_currents_contant = scatter_add(\n        incoming_currents_contant,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_constant_term,\n        dnums,\n    )\n    return incoming_currents_voltages, incoming_currents_contant\n
    "},{"location":"tutorial/01_morph_neurons/","title":"Basics of running simulations in Jaxley","text":"

    In this tutorial, you will learn how to:

    • build your first morphologically detailed cell or read it from SWC
    • stimulate the cell
    • record from the cell
    • visualize cells
    • run your first simulation

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the network.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, dt=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(net)\nplt.plot(v.T)\n

    First, we import the relevant libraries:

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

    We will now build our first cell in Jaxley. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.

    "},{"location":"tutorial/01_morph_neurons/#define-the-cell-from-scratch","title":"Define the cell from scratch","text":"

    To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\n

    Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1 entry means that this branch does not have a parent.

    parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n
    "},{"location":"tutorial/01_morph_neurons/#read-the-cell-from-an-swc-file","title":"Read the cell from an SWC file","text":"

    Alternatively, you could also load cells from SWC with

    cell = jx.read_swc(fname, nseg=4).

    "},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"

    Cells can be visualized as follows:

    cell.compute_xyz()  # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n

    "},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"

    Currently, the cell does not contain any kind of ion channel (not even a leak). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.

    cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n

    The easiest way to know which branch is the zero-eth branch (or, e.g., the zero-eth compartment of the zero-eth branch) is to plot it in a different color:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n_ = cell.branch(0).vis(ax=ax, col=\"r\")\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n

    "},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"

    We next stimulate one of the compartments with a step current. For this, we first define the step current:

    dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n

    We then stimulate one of the compartments of the cell with this step current:

    cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
    Added 1 external_states. See `.externals` for details.\n
    "},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"

    Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:

    cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
    Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n

    We can again visualize these locations to understand where we inserted recordings:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, col=\"g\")\n

    "},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"

    Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate:

    voltages = jx.integrate(cell)\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (2, 402)\n

    The jx.integrate function returns an array of shape (num_recordings, num_timepoints). In our case, we inserted2` recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.

    We can now visualize the voltage response:

    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n

    At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.

    Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to modify parameters of your simulation. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.

    "},{"location":"tutorial/02_small_network/","title":"Network simulations in Jaxley","text":"

    In this tutorial, you will learn how to:

    • connect neurons into a network
    • visualize networks

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n    net.cell(range(10)),\n    net.cell(10),\n    IonotropicSynapse(),\n)\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1])  # or `detail=\"point\"`.\n

    In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
    "},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"

    First, we define a cell as you saw in the previous tutorial.

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

    We can assemble multiple cells into a network by using jx.Network, which takes a list of jx.Cells. Here, we assemble 11 cells into a network:

    num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n

    At this point, we can already visualize this network:

    net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    Note: you can use move_to to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200)

    As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.

    We can use Jaxley\u2019s fully_connect method to connect these layers:

    pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    Let\u2019s visualize this again:

    fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    As you can see, the full_connect method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect method:

    pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n

    "},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"

    We will now set up a simulation of the network. This works exactly as it does for single neurons:

    # Stimulus.\ni_delay = 3.0  # ms\ni_amp = 0.05  # nA\ni_dur = 2.0  # ms\n\n# Duration and step size.\ndt = 0.025  # ms\nt_max = 50.0  # ms\n
    time_vec = jnp.arange(0.0, t_max + dt, dt)\n

    As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.

    net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n

    We stimulate every neuron in the input layer and record the voltage from the output neuron:

    current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n    net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n

    Finally, we can again run the network simulation and plot the result:

    s = jx.integrate(net)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n

    That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. Next, you should learn how to modify parameters of your simulation in this tutorial.

    "},{"location":"tutorial/03_setting_parameters/","title":"Setting parameters and initial states","text":"

    In this tutorial, you will learn how to:

    • set parameters of Jaxley models such as compartment radius or channel conductances
    • set initial states
    • set synaptic parameters

    Here is a code snippet which you will learn to understand in this tutorial:

    cell = ...  # See tutorial on Basics of Jaxley.\ncell.insert(Na())\n\ncell.set(\"radius\", 1.0)  # Set compartment radius.\ncell.branch(0).set(\"Na_gNa\", 0.1)  # Set sodium maximal conductance.\ncell.set(\"v\", -65.0)  # Set initial voltage.\n\nnet = ...  # See tutorial on Networks of Jaxley.\nfully_connect(net.cell(0), net.cell(1), IonotropicSynapse())\nnet.IonotropicSynapse().set(\"IonotropicSynapse_gS\", 0.01)\n

    In the previous two tutorials, you learned how to build single cells or networks and how to simulate them. In this tutorial, you will learn how to change parameters of such simulations.

    Let\u2019s get started!

    import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
    "},{"location":"tutorial/03_setting_parameters/#preface-building-the-cell-or-network","title":"Preface: Building the cell or network","text":"

    We first build a cell (or network) in the same way as we showed in the previous tutorials:

    dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\n
    "},{"location":"tutorial/03_setting_parameters/#setting-parameters-in-jaxley","title":"Setting parameters in Jaxley","text":"

    To modify parameters of the simulation, you can use the .set() method. For example

    cell.set(\"radius\", 0.1)\n
    will modify the radius of every compartment in the cell to 0.1 micrometer. You can also modify the parameters only of some branches:
    cell.branch(0).set(\"radius\", 1.0)\n
    or even of compartments:
    cell.branch(0).comp(0).set(\"radius\", 10.0)\n

    You can always inspect the current parameters by inspecting cell.nodes, which is a pandas Dataframe that contains all information about the cell. You can use .set() to set morphological parameters, channel parameters, synaptic parameters, and initial states, as outlined below:

    "},{"location":"tutorial/03_setting_parameters/#setting-morphological-parameters","title":"Setting morphological parameters","text":"

    Jaxley allows to set the following morphological parameters:

    • radius: the radius of the (zylindrical) compartment (in micrometer)
    • length: the length of the zylindrical compartment (in micrometer)
    • axial_resistivity: the resistivity of current flow between compartments (in ohm centimeter)
    cell.branch(0).set(\"axial_resistivity\", 1000.0)\ncell.set(\"length\", 1.0)  # This will set every compartment in the cell to have length 1.0.\n
    cell.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 1.0 1.0 1000.0 1.0 -70.0 1 1 0 0 1.0 1.0 1000.0 1.0 -70.0 2 2 1 0 1.0 1.0 5000.0 1.0 -70.0 3 3 1 0 1.0 1.0 5000.0 1.0 -70.0"},{"location":"tutorial/03_setting_parameters/#setting-channel-parameters","title":"Setting channel parameters","text":"

    You can also modify channel parameters. Every parameter that should be modifiable has to be defined in self.channel_params of the channel.

    cell.insert(Na())\ncell.branch(1).comp(0).set(\"Na_gNa\", 0.1)\n
    cell.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa eNa vt Na_m Na_h 0 0 0 0 1.0 1.0 1000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2 1 1 0 0 1.0 1.0 1000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2 2 2 1 0 1.0 1.0 5000.0 1.0 -70.0 True 0.10 50.0 -60.0 0.2 0.2 3 3 1 0 1.0 1.0 5000.0 1.0 -70.0 True 0.05 50.0 -60.0 0.2 0.2"},{"location":"tutorial/03_setting_parameters/#setting-synaptic-parameters","title":"Setting synaptic parameters","text":"

    In order to set parameters of synapses, you have to use net.SynapseName.set(), e.g.:

    from jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n\nnum_cells = 2\nnet = jx.Network([cell for _ in range(num_cells)])\nfully_connect(net.cell(0), net.cell(1), IonotropicSynapse())\n\n# Unlike for channels, you have to index into the synapse with `net.SynapseName`\nnet.IonotropicSynapse.set(\"IonotropicSynapse_gS\", 0.1)\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    You can inspect synaptic parameters and states with net.edges:

    net.edges\n
    pre_locs pre_branch_index pre_cell_index post_locs post_branch_index post_cell_index type type_ind global_pre_comp_index global_post_comp_index global_pre_branch_index global_post_branch_index IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s 0 0.25 0 0 0.25 1 1 IonotropicSynapse 0 0 6 0 3 0.1 0.0 0.025 0.2"},{"location":"tutorial/03_setting_parameters/#setting-initial-states","title":"Setting initial states","text":"

    Finally, you can also set initial states. These include the initial voltage v and the states of all channels and synapses (which must be defined in self.channel_states of the channel. For example:

    net.set(\"v\", -72.0)\nnet.IonotropicSynapse.set(\"IonotropicSynapse_s\", 0.1)\n
    net.nodes\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa eNa vt Na_m Na_h 0 0 0 0 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 1 1 0 0 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 2 2 1 0 1.0 1.0 5000.0 1.0 -72.0 True 0.10 50.0 -60.0 0.2 0.2 3 3 1 0 1.0 1.0 5000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 4 4 2 1 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 5 5 2 1 1.0 1.0 1000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2 6 6 3 1 1.0 1.0 5000.0 1.0 -72.0 True 0.10 50.0 -60.0 0.2 0.2 7 7 3 1 1.0 1.0 5000.0 1.0 -72.0 True 0.05 50.0 -60.0 0.2 0.2
    net.edges\n
    pre_locs pre_branch_index pre_cell_index post_locs post_branch_index post_cell_index type type_ind global_pre_comp_index global_post_comp_index global_pre_branch_index global_post_branch_index IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s 0 0.25 0 0 0.25 1 1 IonotropicSynapse 0 0 6 0 3 0.1 0.0 0.025 0.1"},{"location":"tutorial/03_setting_parameters/#summary","title":"Summary","text":"

    You can now modify parameters of your Jaxley simulation. In the next tutorial, you will learn how to make parameter sweeps (or stimulus sweeps) fast with jit-compilation and GPU parallelization.

    "},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations with JIT-compilation and GPUs","text":"

    In this tutorial, you will learn how to:

    • make parameter sweeps in Jaxley
    • use jit to compile your simulations and make them faster
    • use vmap to parallelize simulations on GPUs

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap\n\n\ncell = ...  # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n

    In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:

    • by using JIT compilation
    • by using GPU parallelization

    Let\u2019s get started!

    "},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"

    In Jaxley you can set whether you want to use gpu or cpu with the following lines at the beginning of your script:

    from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n

    JAX (and Jaxley) also allow to choose between float32 and float64. Especially on GPUs, float32 will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32.

    config.update(\"jax_enable_x64\", True)  # Set to false to use `float32`.\n

    Next, we will import relevant libraries:

    import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
    "},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"

    We first build a cell (or network) in the same way as we showed in the previous tutorials:

    dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    "},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"

    Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):

    def simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state)\n

    The .data_set() method takes three arguments:

    1) the name of the parameter you want to set. Jaxley allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state which is initialized as None and is modified by .data_set(). This has to be passed to jx.integrate().

    Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:

    # Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (5, 1, 402)\n

    The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).

    "},{"location":"tutorial/04_jit_and_vmap/#stimulus-sweeps","title":"Stimulus sweeps","text":"

    In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate() method:

    def simulate(i_amp):\n    current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n    data_stimuli = None\n    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n    return jx.integrate(cell, data_stimuli=data_stimuli)\n

    "},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit compilation","text":"

    We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX\u2019s jit():

    jitted_simulate = jit(simulate)\n
    # First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
    # More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
    voltages.shape (5, 1, 402)\n

    jit compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed up at all).

    "},{"location":"tutorial/04_jit_and_vmap/#speeding-up-with-gpu-parallelization-via-vmap","title":"Speeding up with GPU parallelization via vmap","text":"

    Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley can be achieved by using vmap of JAX. To do this, we first create a new function that handles multiple parameter sets directly:

    # Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n

    We can then run this method on all parameter sets (all_params.shape == (100, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu as device in the beginning of this tutorial.

    voltages = vmapped_simulate(all_params)\n

    GPU parallelization with vmap can give a large speed-up, which can easily be 2-3 orders of magnitude.

    "},{"location":"tutorial/04_jit_and_vmap/#combining-jit-and-vmap","title":"Combining jit and vmap","text":"

    Finally, you can also combine using jit and vmap. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap and simulating each batch can be compiled with jit:

    jitted_vmapped_simulate = jit(vmap(simulate))\n
    for batch in range(10):\n    all_params = jnp.asarray(np.random.rand(5, 2))\n    voltages_batch = jitted_vmapped_simulate(all_params)\n

    That\u2019s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.

    "},{"location":"tutorial/04_jit_and_vmap/#next-steps","title":"Next steps","text":"

    If you want to learn more, we recommend you to read the tutorial on building channel and synapse models or to read the tutorial on groups, which allow to make your Jaxley simulations more elegant and convenient to interact with.

    Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

    "},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building and using ion channel models","text":"

    In this tutorial, you will learn how to:

    • define your own ion channel models beyond the preconfigured channels in Jaxley

    This tutorial assumes that you have already learned how to build basic simulations.

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n

    First, we define a cell as you saw in the previous tutorial:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

    You have also already learned how to insert preconfigured channels into Jaxley models:

    cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n

    In this tutorial, we will show you how to build your own channel and synapse models.

    "},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"

    Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

    import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n\n    def update_states(self, states, dt, v, params):\n        \"\"\"Update state.\"\"\"\n        ns = states[\"n_new\"]\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        new_n = solve_gate_exponential(ns, dt, alpha, beta)\n        return {\"n_new\": new_n}\n\n    def compute_current(self, states, v, params):\n        \"\"\"Return current.\"\"\"\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4 * 1000  # mS/cm^2\n\n        e_kd = -77.0        \n        return kd_conds * (v - e_kd)\n\n    def init_state(self, v, params):\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        return {\"n_new\": alpha / (alpha + beta)}\n

    Let\u2019s look at each part of this in detail.

    The below is simply a helper function for the solver of the gate variables:

    def exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n

    Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name.

    class Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name=None):\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n

    Next, we have the update_states() method, which updates the gating variables:

        def update_states(self, states, dt, v, params):\n

    The inputs states to the update_states method is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let\u2019s get the state of the potassium channel which we are building here:

    ns = states[\"n_new\"]\n

    Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

    alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n

    A channel also needs a compute_current() method which returns the current throught the channel:

        def compute_current(self, states, v, params):\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4 * 1000  # mS/cm^2\n\n        e_kd = -77.0        \n        current = kd_conds * (v - e_kd)\n        return current\n

    Finally, the init_state() method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states() is run.

    Alright, done! We can now insert this channel into any jx.Module such as our cell:

    cell.insert(Potassium())\n
    cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    s = jx.integrate(cell)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

    "},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"

    The parts below assume that you have already learned how to build network simulations in Jaxley.

    The below is an example of how to define your own synapse model in Jaxley:`

    import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n    \"\"\"\n    Compute syanptic current and update syanpse state.\n    \"\"\"\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n        self.synapse_states = {\"s_chol\": 0.1}\n\n    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n        \"\"\"Return updated synapse state and current.\"\"\"\n        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n        exp_term = jnp.exp(-delta_t)\n        new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {\"s_chol\": new_s}\n\n    def compute_current(self, states, pre_voltage, post_voltage, params):\n        g_syn = params[\"gChol\"] * states[\"s_chol\"]\n        return g_syn * (post_voltage - params[\"eChol\"])\n

    As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current method takes two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()).

    net = jx.Network([cell for _ in range(3)])\n
    from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n
    net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n    net.cell(i).branch(0).loc(0.0).record()\n
    Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    s = jx.integrate(net)\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

    That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

    This tutorial does not have an immediate follow-up tutorial. You could read the tutorial on groups, which allow to make your Jaxley simulations more elegant and convenient to interact with.

    Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

    "},{"location":"tutorial/06_groups/","title":"Defining groups for easier handling of complex networks","text":"

    In this tutorial, you will learn how to:

    • define groups (aka sectionlists) to simplify iteractions with Jaxley

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap\n\n\nnet = ...  # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n    param_state = None\n    param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n    param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n    return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n

    In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

    First, we define a network as you saw in the previous tutorial:

    comp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n
    "},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"

    Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:

    for cell_ind in range(3):\n    network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n    network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n

    After this, we can access network.apical as we previously accesses anything else:

    network.apical.set(\"radius\", 0.3)\n
    network.apical.view\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa ... K_gK eK K_n Leak Leak_gLeak Leak_eLeak global_comp_index global_branch_index global_cell_index controlled_by_param 2 2 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 2 1 0 0 3 3 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 3 1 0 0 6 6 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 6 3 0 0 7 7 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 7 3 0 0 10 10 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 10 5 1 0 11 11 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 11 5 1 0 14 14 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 14 7 1 0 15 15 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 15 7 1 0 18 18 9 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 18 9 2 0 19 19 9 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 19 9 2 0 22 22 11 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 22 11 2 0 23 23 11 2 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... 0.005 -90.0 0.2 True 0.0001 -70.0 23 11 2 0

    12 rows \u00d7 25 columns

    "},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"

    Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:

    network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
    network.fast_spiking.set(\"Na_gNa\", 0.4)\n
    network.fast_spiking.view\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v Na Na_gNa ... K_gK eK K_n Leak Leak_gLeak Leak_eLeak global_comp_index global_branch_index global_cell_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 0 0 0 0 1 1 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 1 0 0 0 2 2 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 2 1 0 0 3 3 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 3 1 0 0 4 4 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 4 2 0 0 5 5 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 5 2 0 0 6 6 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 6 3 0 0 7 7 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 7 3 0 0 8 8 4 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 8 4 1 0 9 9 4 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 9 4 1 0 10 10 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 10 5 1 0 11 11 5 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 11 5 1 0 12 12 6 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 12 6 1 0 13 13 6 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 13 6 1 0 14 14 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 14 7 1 0 15 15 7 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.005 -90.0 0.2 True 0.0001 -70.0 15 7 1 0

    16 rows \u00d7 25 columns

    "},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"

    Note: If you are reading swc morphologigies, you can automatically assign groups with jx.read_swc(file_name, assign_groups=True). After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.

    "},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()","text":"

    If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

    network.fast_spiking.make_trainable(\"Na_gNa\")\n
    Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

    As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

    network.get_parameters()\n
    [{'Na_gNa': Array([0.4], dtype=float64)}]\n

    If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1,3]):

    network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
    Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
    network.get_parameters()\n
    [{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n

    This generated three parameters for the axial resistivitiy, each corresponding to one cell.

    "},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"

    Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable().

    "},{"location":"tutorial/07_gradient_descent/","title":"Training biophysical models","text":"

    In this tutorial, you will learn how to train biophysical models in Jaxley. This includes the following:

    • compute the gradient with respect to parameters
    • use parameter transformations
    • use multi-level checkpointing
    • define optimizers
    • write dataloaders and parallelize across data

    Here is a code snippet which you will learn to understand in this tutorial:

    from jax import jit, vmap, value_and_grad\nimport jaxley as jx\n\n\nnet = ...  # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform(\n    lowers={\"HH_gNa\": 0.0, \"IonotropicSynapse_gS\": 0.0},\n    uppers={\"HH_gNa\": 1.0, \"IonotropicSynapse_gS\": 1.0},\n)\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n    current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n    data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20])\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n    params = transform.forward(opt_params)\n    voltages = batch_simulate(params, datapoints)\n    return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n    for batch in dataloader:\n        stimuli = batch[0].numpy()\n        labels = batch[1].numpy()\n        loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n        # Optimizer step.\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n

    from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n

    First, we define a network as you saw in the previous tutorial:

    _ = np.random.seed(0)  # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n
    /Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    This network consists of three neurons arranged in two layers:

    net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\", layers=[2, 1], layer_kwargs={\"within_layer_offset\": 100.0, \"between_layer_offset\": 100.0}) \n

    We consider the last neuron as the output neuron and record the voltage from there:

    net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
    Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
    "},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"

    We will train this biophysical network on a classification task. The inputs will be values and the label is binary:

    inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n

    labels = labels.astype(float)\n
    "},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"
    net.delete_trainables()\n

    This follows the same API as .set() seen in the previous tutorial. If you want to use a single parameter for all radiuses in the entire network, do:

    net.make_trainable(\"radius\")\n
    Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

    We can also define parameters for individual compartments. To do this, use the \"all\" key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:

    net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
    Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
    "},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"

    Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do

    net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n

    Here, we use a different syanptic conductance for all syanpses. This can be done as follows:

    net.TanhRateSynapse(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
    Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
    "},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"

    Once all parameters are defined, you have to use .get_parameters() to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:

    params = net.get_parameters()\n

    You can now run the simulation with the trainable parameters by passing them to the jx.integrate function.

    s = jx.integrate(net, params=params, t_max=10.0)\n
    "},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"

    The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:

    def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n

    We can also inspect some traces:

    traces = batched_simulate(params, inputs[:4])\n
    fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n

    "},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"

    Let us define a loss function to be optimized:

    def loss(params, inputs, labels):\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0) / 5  # Such that the prediction is roughly in [0, 1].\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n

    And we can use JAX\u2019s inbuilt functions to take the gradient through the entire ODE:

    jitted_grad = jit(value_and_grad(loss, argnums=0))\n
    value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
    "},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"

    Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)

    transform = jx.ParamTransform(\n    lowers={\n        \"Leak_gLeak\": 1e-5,\n        \"radius\": 0.1,\n        \"TanhRateSynapse_gS\": 1e-5,\n    },\n    uppers={\n        \"Leak_gLeak\": 1e-3,\n        \"radius\": 5.0,\n        \"TanhRateSynapse_gS\": 1e-2,\n    }, \n)\n

    With these modify the loss function acocrdingly:

    def loss(opt_params, inputs, labels):\n    transform.forward(opt_params)\n\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0)  # Such that the prediction is around 0.\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n
    "},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"

    Checkpointing allows to vastly reduce the memory requirements of training biophysical models.

    t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n

    To enable checkpointing, we have to modify the simulate function appropriately and use

    jx.integrate(..., checkpoint_inds=checkpoints)\n
    as done below:

    def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n    traces = simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[2])  # Use the average over time of the output neuron (2) as prediction.\n    return prediction + 72.0  # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n    params = transform.forward(opt_params)\n\n    predictions = batched_predict(params, inputs)\n    losses = jnp.abs(predictions - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
    "},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"

    We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax first):

    import optax\n
    opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
    "},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"
    import tensorflow as tf\nfrom tensorflow.data import Dataset\n
    batch_size = 4\n\ntf.random.set_seed(1)\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(batch_size)\n
    "},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"
    for epoch in range(10):\n    epoch_loss = 0.0\n    for batch_ind, batch in enumerate(dataloader):\n        current_batch = batch[0].numpy()\n        label_batch = batch[1].numpy()\n        loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n        epoch_loss += loss_val\n\n    print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
    epoch 0, loss 25.61663325387099\nepoch 1, loss 21.7304402547341\nepoch 2, loss 15.943236054666484\nepoch 3, loss 9.191846765081072\nepoch 4, loss 7.256558484588674\nepoch 5, loss 6.577375342584615\nepoch 6, loss 6.568056585075223\nepoch 7, loss 6.510474263850299\nepoch 8, loss 6.481302675498705\nepoch 9, loss 6.5030439519558865\n
    ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
    fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n

    Indeed, the loss goes down and the network successfully classifies the patterns.

    "},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"

    Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:

    • compute the gradient with respect to parameters
    • use parameter transformations
    • use multi-level checkpointing
    • define optimizers
    • write dataloaders and parallelize across data

    This was the last tutorial of the Jaxley toolbox. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!

    "},{"location":"tutorial/08_importing_morphologies/","title":"Working with morphologies","text":"

    In this tutorial, you will learn how to:

    • Load morphologies and make them compatible with Jaxley
    • How to use the visualization features
    • How to assemble a small network of morphologically accurate cells.

    Here is a code snippet which you will learn to understand in this tutorial:

    import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", nseg=4, assign_groups=True)\n

    To work with more complicated morphologies, Jaxley supports importing morphological reconstructions via .swc files. .swc is currently the only supported format. Other formats like .asc need to be converted to .swc first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc see here.

    import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n

    To work with .swc files, Jaxley implements a custom .swc reader. The reader traces the morphology and identifies all uninterrupted sections. These are then partitioned into branches, each of which will be approximated by a number of equally many compartments that can be simulated fully in parallel.

    To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.

    # import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, nseg=8, max_branch_len=2000.0, assign_groups=True)\n\n# print shape (num_cells, num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
    (1, 157, 8)\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 0.01250 8.119 5000.0 1.0 -70.0 1 1 0 0 0.01250 8.119 5000.0 1.0 -70.0 2 2 0 0 0.01250 8.119 5000.0 1.0 -70.0 3 3 0 0 0.01250 8.119 5000.0 1.0 -70.0 4 4 0 0 0.01250 8.119 5000.0 1.0 -70.0 ... ... ... ... ... ... ... ... ... 1251 1251 156 0 24.12382 0.550 5000.0 1.0 -70.0 1252 1252 156 0 24.12382 0.550 5000.0 1.0 -70.0 1253 1253 156 0 24.12382 0.550 5000.0 1.0 -70.0 1254 1254 156 0 24.12382 0.550 5000.0 1.0 -70.0 1255 1255 156 0 24.12382 0.550 5000.0 1.0 -70.0

    1256 rows \u00d7 8 columns

    As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:

    cell = jx.read_swc(fname, nseg=2, max_branch_len=2000.0, assign_groups=True)\n\n# print shape (num_cells, num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
    (1, 157, 2)\n
    comp_index branch_index cell_index length radius axial_resistivity capacitance v 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 1 1 0 0 0.050000 8.119000 5000.0 1.0 -70.0 2 2 1 0 6.241557 7.493344 5000.0 1.0 -70.0 3 3 1 0 6.241557 4.273686 5000.0 1.0 -70.0 4 4 2 0 4.160500 7.960000 5000.0 1.0 -70.0 ... ... ... ... ... ... ... ... ... 309 309 154 0 49.728572 0.400000 5000.0 1.0 -70.0 310 310 155 0 46.557908 0.494201 5000.0 1.0 -70.0 311 311 155 0 46.557908 0.302202 5000.0 1.0 -70.0 312 312 156 0 96.495281 0.742532 5000.0 1.0 -70.0 313 313 156 0 96.495281 0.550000 5000.0 1.0 -70.0

    314 rows \u00d7 8 columns

    # visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n

    While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc reconstruction irrespective of the number of compartments used. This is stored in the cell.xyzr attribute in a per branch fashion.

    To highlight each branch seperately, we can iterate over them.

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i in range(cell.shape[1]):\n    cell.branch(i).vis(ax=ax, col=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n

    Since Jaxley supports grouping different branches or compartments together, we can also use the id labels provided by the .swc file to assign group labels to the jx.Cell object.

    print(list(cell.group_nodes.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, col=colors[2])\ncell.soma.vis(ax=ax, col=colors[1])\ncell.apical.vis(ax=ax, col=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
    ['soma', 'basal', 'apical', 'custom']\n

    To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.

    net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n
    /home/jnsbck/Uni/PhD/projects/jaxleyverse/jaxley/jaxley/modules/base.py:1528: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n  self.pointer.edges = pd.concat(\n

    Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutroial on how to build a network.

    "}]} \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz index 5a98f6137c31cd5c3450a2c4c8c766b519f36161..7c87c0185b85ab991a510f5ef1e149cf49213f75 100644 GIT binary patch delta 15 Wcmdnayq%d%zMF$X$896qMn(W3(*zIz delta 15 Wcmdnayq%d%zMF$X#%UwlMn(W3UjzIA