diff --git a/docs/tutorials/05_channel_and_synapse_models.ipynb b/docs/tutorials/05_channel_and_synapse_models.ipynb index c1b1e715..1009fbbf 100644 --- a/docs/tutorials/05_channel_and_synapse_models.ipynb +++ b/docs/tutorials/05_channel_and_synapse_models.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c19e95c6", + "id": "84d5d597", "metadata": {}, "source": [ "# Building and using ion channel models\n", @@ -17,7 +17,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "11544a58", + "id": "a9a16457", "metadata": {}, "outputs": [], "source": [ @@ -36,7 +36,7 @@ }, { "cell_type": "markdown", - "id": "e141cff7-ffc1-4e4a-bfe5-2b87792de748", + "id": "1a91e1c7", "metadata": {}, "source": [ "First, we define a cell as you saw in the [previous tutorial](https://jaxleyverse.github.io/jaxley/latest/tutorial/01_morph_neurons/):" @@ -45,7 +45,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "b308ed28-193c-4c15-be6c-019d81c03883", + "id": "082c1774", "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "markdown", - "id": "4257adb3-df24-4a03-982c-0ae0983067b1", + "id": "9c181374", "metadata": {}, "source": [ "You have also already learned how to insert preconfigured channels into `Jaxley` models:\n", @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "b4521def", + "id": "fc631157", "metadata": {}, "source": [ "### Your own channel\n", @@ -81,7 +81,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "6edc6581", + "id": "e8d96197", "metadata": {}, "outputs": [], "source": [ @@ -97,6 +97,7 @@ " \"\"\"Potassium channel.\"\"\"\n", "\n", " def __init__(self, name = None):\n", + " self.current_is_in_mA_per_cm2 = True\n", " super().__init__(name)\n", " self.channel_params = {\"gK_new\": 1e-4}\n", " self.channel_states = {\"n_new\": 0.0}\n", @@ -113,9 +114,7 @@ " def compute_current(self, states, v, params):\n", " \"\"\"Return current.\"\"\"\n", " ns = states[\"n_new\"]\n", - " \n", - " # Multiply with 1000 to convert Siemens to milli Siemens.\n", - " kd_conds = params[\"gK_new\"] * ns**4 * 1000 # mS/cm^2\n", + " kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n", "\n", " e_kd = -77.0 \n", " return kd_conds * (v - e_kd)\n", @@ -128,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "204afb5b", + "id": "c7812bf1", "metadata": {}, "source": [ "Let's look at each part of this in detail. \n", @@ -139,12 +138,13 @@ " return x / (jnp.exp(x / y) - 1.0)\n", "```\n", "\n", - "Next, we define our channel as a class. It should inherit from the `Channel` class and define `channel_params`, `channel_states`, and `current_name`.\n", + "Next, we define our channel as a class. It should inherit from the `Channel` class and define `channel_params`, `channel_states`, and `current_name`. You also need to set `self.current_is_in_mA_per_cm2=True` as the first line on your `__init__()` method. This is to acknowledge that your current is returned in `mA/cm2` (not in `uA/cm2`, as would have been required in Jaxley versions 0.4.0 or older).\n", "```python\n", "class Potassium(Channel):\n", " \"\"\"Potassium channel.\"\"\"\n", "\n", " def __init__(self, name=None):\n", + " self.current_is_in_mA_per_cm2 = True\n", " super().__init__(name)\n", " self.channel_params = {\"gK_new\": 1e-4}\n", " self.channel_states = {\"n_new\": 0.0}\n", @@ -175,7 +175,7 @@ " ns = states[\"n_new\"]\n", " \n", " # Multiply with 1000 to convert Siemens to milli Siemens.\n", - " kd_conds = params[\"gK_new\"] * ns**4 * 1000 # mS/cm^2\n", + " kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n", "\n", " e_kd = -77.0 \n", " current = kd_conds * (v - e_kd)\n", @@ -187,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "c6c17ea2", + "id": "794f6279", "metadata": {}, "source": [ "Alright, done! We can now insert this channel into any `jx.Module` such as our cell:" @@ -196,7 +196,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "8bc79fc5", + "id": "94c270e3", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "23447b7e-7554-4e82-b702-db0cb6b517ca", + "id": "313ee72f", "metadata": {}, "outputs": [ { @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "f4c678f2", + "id": "08a81449", "metadata": {}, "outputs": [], "source": [ @@ -240,7 +240,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "1a10abc0", + "id": "4318f1bc", "metadata": {}, "outputs": [ { @@ -264,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "43c9f1ea", + "id": "e302590d", "metadata": {}, "source": [ "### Your own synapse\n", @@ -277,7 +277,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "62c719f9", + "id": "5c68fee6", "metadata": {}, "outputs": [], "source": [ @@ -308,7 +308,7 @@ }, { "cell_type": "markdown", - "id": "0de23892-4147-4ca2-a39d-97e4985c8e27", + "id": "d608b85c", "metadata": {}, "source": [ "As you can see above, synapses follow closely how channels are defined. The main difference is that the `compute_current` method takes two voltages: the pre-synaptic voltage (a `jnp.ndarray` of shape `()`) and the post-synaptic voltage (a `jnp.ndarray` of shape `()`)." @@ -317,7 +317,7 @@ { "cell_type": "code", "execution_count": 10, - "id": "97ee3c49", + "id": "0652ace1", "metadata": {}, "outputs": [], "source": [ @@ -327,7 +327,7 @@ { "cell_type": "code", "execution_count": 11, - "id": "a8c9428c-d0dc-4892-966e-11a43aea216d", + "id": "69dde19f", "metadata": {}, "outputs": [], "source": [ @@ -341,7 +341,7 @@ { "cell_type": "code", "execution_count": 12, - "id": "fc6fd5a2", + "id": "4f19d456", "metadata": {}, "outputs": [ { @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": 13, - "id": "867f4f38", + "id": "3e955552", "metadata": {}, "outputs": [], "source": [ @@ -374,7 +374,7 @@ { "cell_type": "code", "execution_count": 14, - "id": "01782d18", + "id": "9cfc8f33", "metadata": {}, "outputs": [ { @@ -398,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "5b944bee-bf80-4f6b-a5d2-5cd119e6d4a4", + "id": "e3f88051", "metadata": {}, "source": [ "That's it! You are now ready to build your own custom simulations and equip them with channel and synapse models!\n", diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index bf6f1d1a..678b1e1e 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple +from warnings import warn import jax.numpy as jnp @@ -20,6 +21,24 @@ class Channel: current_name = None def __init__(self, name: Optional[str] = None): + contact = ( + "If you have any questions, please reach out via email to " + "michael.deistler@uni-tuebingen.de or create an issue on Github: " + "https://github.com/jaxleyverse/jaxley/issues. Thank you!" + ) + if ( + not hasattr(self, "current_is_in_mA_per_cm2") + or not self.current_is_in_mA_per_cm2 + ): + raise ValueError( + "The channel you are using is deprecated. " + "In Jaxley version 0.5.0, we changed the unit of the current returned " + "by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please " + "update your channel model (by dividing the resulting current by 1000) " + "and set `self.current_is_in_mA_per_cm2=True` as the first line " + f"in the `__init__()` method of your channel. {contact}" + ) + self._name = name if name else self.__class__.__name__ @property diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index a4c35b2e..c19bf002 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -13,6 +13,8 @@ class HH(Channel): """Hodgkin-Huxley channel.""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -52,10 +54,9 @@ def compute_current( prefix = self._name m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"] - # Multiply with 1000 to convert Siemens to milli Siemens. - gNa = params[f"{prefix}_gNa"] * (m**3) * h * 1000 # mS/cm^2 - gK = params[f"{prefix}_gK"] * n**4 * 1000 # mS/cm^2 - gLeak = params[f"{prefix}_gLeak"] * 1000 # mS/cm^2 + gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 + gK = params[f"{prefix}_gK"] * n**4 # S/cm^2 + gLeak = params[f"{prefix}_gLeak"] # S/cm^2 return ( gNa * (v - params[f"{prefix}_eNa"]) diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 2ce3b94f..5884deac 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -36,6 +36,8 @@ class Leak(Channel): """Leak current""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -60,8 +62,7 @@ def compute_current( ): """Return current.""" prefix = self._name - # Multiply with 1000 to convert Siemens to milli Siemens. - gLeak = params[f"{prefix}_gLeak"] * 1000 # mS/cm^2 + gLeak = params[f"{prefix}_gLeak"] # S/cm^2 return gLeak * (v - params[f"{prefix}_eLeak"]) def init_state(self, states, v, params, delta_t): @@ -72,6 +73,8 @@ class Na(Channel): """Sodium channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -103,8 +106,7 @@ def compute_current( prefix = self._name m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] - # Multiply with 1000 to convert Siemens to milli Siemens. - gNa = params[f"{prefix}_gNa"] * (m**3) * h * 1000 # mS/cm^2 + gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 current = gNa * (v - params["eNa"]) return current @@ -142,6 +144,8 @@ class K(Channel): """Potassium channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -172,8 +176,7 @@ def compute_current( prefix = self._name n = states[f"{prefix}_n"] - # Multiply with 1000 to convert Siemens to milli Siemens. - gK = params[f"{prefix}_gK"] * (n**4) * 1000 # mS/cm^2 + gK = params[f"{prefix}_gK"] * (n**4) # S/cm^2 return gK * (v - params["eK"]) @@ -197,6 +200,8 @@ class Km(Channel): """Slow M Potassium channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -229,8 +234,7 @@ def compute_current( prefix = self._name p = states[f"{prefix}_p"] - # Multiply with 1000 to convert Siemens to milli Siemens. - gKm = params[f"{prefix}_gKm"] * p * 1000 # mS/cm^2 + gKm = params[f"{prefix}_gKm"] * p # S/cm^2 return gKm * (v - params["eK"]) def init_state(self, states, v, params, delta_t): @@ -253,6 +257,8 @@ class CaL(Channel): """L-type Calcium channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -282,9 +288,7 @@ def compute_current( """Return current.""" prefix = self._name q, r = states[f"{prefix}_q"], states[f"{prefix}_r"] - - # Multiply with 1000 to convert Siemens to milli Siemens. - gCaL = params[f"{prefix}_gCaL"] * (q**2) * r * 1000 # mS/cm^2 + gCaL = params[f"{prefix}_gCaL"] * (q**2) * r # S/cm^2 return gCaL * (v - params["eCa"]) @@ -321,6 +325,8 @@ class CaT(Channel): """T-type Calcium channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) prefix = self._name self.channel_params = { @@ -354,8 +360,7 @@ def compute_current( u = states[f"{prefix}_u"] s_inf = 1.0 / (1.0 + save_exp(-(v + params[f"{prefix}_vx"] + 57.0) / 6.2)) - # Multiply with 1000 to convert Siemens to milli Siemens. - gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u * 1000 # mS/cm^2 + gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u # S/cm^2 return gCaT * (v - params["eCa"]) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d00654a9..acf300e2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1805,8 +1805,10 @@ def _channel_currents( ) voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff constant_term = membrane_currents[0] - voltage_term * voltages[indices] - voltage_terms = voltage_terms.at[indices].add(voltage_term) - constant_terms = constant_terms.at[indices].add(-constant_term) + + # * 1000 to convert from mA/cm^2 to uA/cm^2. + voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0) + constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. diff --git a/tests/test_channels.py b/tests/test_channels.py index 89403626..41024040 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -10,11 +10,104 @@ import jax.numpy as jnp import numpy as np import pytest -from jaxley_mech.channels.l5pc import CaNernstReversal, CaPump import jaxley as jx from jaxley.channels import HH, CaL, CaT, Channel, K, Km, Leak, Na -from jaxley.solver_gate import save_exp, solve_inf_gate_exponential +from jaxley.solver_gate import exponential_euler, save_exp, solve_inf_gate_exponential + + +class CaPump(Channel): + """Calcium dynamics tracking inside calcium concentration, modeled after Destexhe et al. 1994.""" + + def __init__( + self, + name: Optional[str] = None, + ): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) + self.channel_params = { + f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) + f"{self._name}_decay": 80, # Rate of removal of calcium in ms + f"{self._name}_depth": 0.1, # Depth of shell in um + f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM + } + self.channel_states = { + f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM + } + self.current_name = f"i_Ca" + self.META = { + "reference": "Modified from Destexhe et al., 1994", + "mechanism": "Calcium dynamics", + } + + def update_states(self, u, dt, voltages, params): + """Update internal calcium concentration based on calcium current and decay.""" + prefix = self._name + ica = u["i_Ca"] / 1_000.0 + cai = u["CaCon_i"] + gamma = params[f"{prefix}_gamma"] + decay = params[f"{prefix}_decay"] + depth = params[f"{prefix}_depth"] + minCai = params[f"{prefix}_minCai"] + + FARADAY = 96485 # Coulombs per mole + + # Calculate the contribution of calcium currents to cai change + drive_channel = -10_000.0 * ica * gamma / (2 * FARADAY * depth) + + cai_tau = decay + cai_inf = minCai + decay * drive_channel + new_cai = exponential_euler(cai, dt, cai_inf, cai_tau) + + return {f"CaCon_i": new_cai} + + def compute_current(self, u, voltages, params): + """This dynamics model does not directly contribute to the membrane current.""" + return 0 + + def init_state(self, states, voltages, params, delta_t): + """Initialize the state at fixed point of gate dynamics.""" + return {} + + +class CaNernstReversal(Channel): + """Compute Calcium reversal from inner and outer concentration of calcium.""" + + def __init__( + self, + name: Optional[str] = None, + ): + self.current_is_in_mA_per_cm2 = True + super().__init__(name) + self.channel_constants = { + "F": 96485.3329, # C/mol (Faraday's constant) + "T": 279.45, # Kelvin (temperature) + "R": 8.314, # J/(mol K) (gas constant) + } + self.channel_params = {} + self.channel_states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} + self.current_name = f"i_Ca" + + def update_states(self, u, dt, voltages, params): + """Update internal calcium concentration based on calcium current and decay.""" + R, T, F = ( + self.channel_constants["R"], + self.channel_constants["T"], + self.channel_constants["F"], + ) + Cai = u["CaCon_i"] + Cao = u["CaCon_e"] + C = R * T / (2 * F) * 1000 # mV + vCa = C * jnp.log(Cao / Cai) + return {"eCa": vCa, "CaCon_i": Cai, "CaCon_e": Cao} + + def compute_current(self, u, voltages, params): + """This dynamics model does not directly contribute to the membrane current.""" + return 0 + + def init_state(self, states, voltages, params, delta_t): + """Initialize the state at fixed point of gate dynamics.""" + return {} def test_channel_set_name(): @@ -105,6 +198,7 @@ def test_init_states(): class KCA11(Channel): def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) prefix = self._name self.channel_params = { @@ -196,6 +290,7 @@ class User(Channel): """The channel which uses currents of Dummy1 and Dummy2 to update its states.""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = {} self.channel_states = {"cumulative": 0.0} @@ -211,6 +306,7 @@ def compute_current(self, states, v, params): class Dummy1(Channel): def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = {} self.channel_states = {} @@ -224,6 +320,7 @@ def compute_current(self, states, v, params): class Dummy2(Channel): def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = {} self.channel_states = {} diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 9f29ff7b..3e7642ce 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -20,6 +20,7 @@ class Dummy1(Channel): """A dummy channel which simply accumulates a state (same state as dummy2).""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = {} self.channel_states = {"Dummy_s": 0.0} @@ -42,6 +43,7 @@ class Dummy2(Channel): """A dummy channel which simply accumulates a state (same state as dummy1).""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = {} self.channel_states = {"Dummy_s": 0.0} @@ -64,6 +66,7 @@ class CaHVA(Channel): """High-Voltage-Activated (HVA) Ca2+ channel""" def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = { f"{self._name}_gCaHVA": 0.00001, # S/cm^2 @@ -135,6 +138,7 @@ def __init__( self, name: Optional[str] = None, ): + self.current_is_in_mA_per_cm2 = True super().__init__(name) self.channel_params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered)