Skip to content

Commit

Permalink
fix units of compute_current (#461)
Browse files Browse the repository at this point in the history
* fix units of compute_current

* Enforce new unit convention

* Add raise and warn for old channel models

* add option to use `current_is_in_mA_per_cm2=False`

* Add test

* Remove warning

* fixups
  • Loading branch information
michaeldeistler authored Oct 24, 2024
1 parent 61c5c27 commit 5ce7c45
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 48 deletions.
54 changes: 27 additions & 27 deletions docs/tutorials/05_channel_and_synapse_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "c19e95c6",
"id": "84d5d597",
"metadata": {},
"source": [
"# Building and using ion channel models\n",
Expand All @@ -17,7 +17,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "11544a58",
"id": "a9a16457",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,7 +36,7 @@
},
{
"cell_type": "markdown",
"id": "e141cff7-ffc1-4e4a-bfe5-2b87792de748",
"id": "1a91e1c7",
"metadata": {},
"source": [
"First, we define a cell as you saw in the [previous tutorial](https://jaxleyverse.github.io/jaxley/latest/tutorial/01_morph_neurons/):"
Expand All @@ -45,7 +45,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "b308ed28-193c-4c15-be6c-019d81c03883",
"id": "082c1774",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -56,7 +56,7 @@
},
{
"cell_type": "markdown",
"id": "4257adb3-df24-4a03-982c-0ae0983067b1",
"id": "9c181374",
"metadata": {},
"source": [
"You have also already learned how to insert preconfigured channels into `Jaxley` models:\n",
Expand All @@ -71,7 +71,7 @@
},
{
"cell_type": "markdown",
"id": "b4521def",
"id": "fc631157",
"metadata": {},
"source": [
"### Your own channel\n",
Expand All @@ -81,7 +81,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "6edc6581",
"id": "e8d96197",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -97,6 +97,7 @@
" \"\"\"Potassium channel.\"\"\"\n",
"\n",
" def __init__(self, name = None):\n",
" self.current_is_in_mA_per_cm2 = True\n",
" super().__init__(name)\n",
" self.channel_params = {\"gK_new\": 1e-4}\n",
" self.channel_states = {\"n_new\": 0.0}\n",
Expand All @@ -113,9 +114,7 @@
" def compute_current(self, states, v, params):\n",
" \"\"\"Return current.\"\"\"\n",
" ns = states[\"n_new\"]\n",
" \n",
" # Multiply with 1000 to convert Siemens to milli Siemens.\n",
" kd_conds = params[\"gK_new\"] * ns**4 * 1000 # mS/cm^2\n",
" kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n",
"\n",
" e_kd = -77.0 \n",
" return kd_conds * (v - e_kd)\n",
Expand All @@ -128,7 +127,7 @@
},
{
"cell_type": "markdown",
"id": "204afb5b",
"id": "c7812bf1",
"metadata": {},
"source": [
"Let's look at each part of this in detail. \n",
Expand All @@ -139,12 +138,13 @@
" return x / (jnp.exp(x / y) - 1.0)\n",
"```\n",
"\n",
"Next, we define our channel as a class. It should inherit from the `Channel` class and define `channel_params`, `channel_states`, and `current_name`.\n",
"Next, we define our channel as a class. It should inherit from the `Channel` class and define `channel_params`, `channel_states`, and `current_name`. You also need to set `self.current_is_in_mA_per_cm2=True` as the first line on your `__init__()` method. This is to acknowledge that your current is returned in `mA/cm2` (not in `uA/cm2`, as would have been required in Jaxley versions 0.4.0 or older).\n",
"```python\n",
"class Potassium(Channel):\n",
" \"\"\"Potassium channel.\"\"\"\n",
"\n",
" def __init__(self, name=None):\n",
" self.current_is_in_mA_per_cm2 = True\n",
" super().__init__(name)\n",
" self.channel_params = {\"gK_new\": 1e-4}\n",
" self.channel_states = {\"n_new\": 0.0}\n",
Expand Down Expand Up @@ -175,7 +175,7 @@
" ns = states[\"n_new\"]\n",
" \n",
" # Multiply with 1000 to convert Siemens to milli Siemens.\n",
" kd_conds = params[\"gK_new\"] * ns**4 * 1000 # mS/cm^2\n",
" kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n",
"\n",
" e_kd = -77.0 \n",
" current = kd_conds * (v - e_kd)\n",
Expand All @@ -187,7 +187,7 @@
},
{
"cell_type": "markdown",
"id": "c6c17ea2",
"id": "794f6279",
"metadata": {},
"source": [
"Alright, done! We can now insert this channel into any `jx.Module` such as our cell:"
Expand All @@ -196,7 +196,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "8bc79fc5",
"id": "94c270e3",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -206,7 +206,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "23447b7e-7554-4e82-b702-db0cb6b517ca",
"id": "313ee72f",
"metadata": {},
"outputs": [
{
Expand All @@ -230,7 +230,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "f4c678f2",
"id": "08a81449",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -240,7 +240,7 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "1a10abc0",
"id": "4318f1bc",
"metadata": {},
"outputs": [
{
Expand All @@ -264,7 +264,7 @@
},
{
"cell_type": "markdown",
"id": "43c9f1ea",
"id": "e302590d",
"metadata": {},
"source": [
"### Your own synapse\n",
Expand All @@ -277,7 +277,7 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "62c719f9",
"id": "5c68fee6",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -308,7 +308,7 @@
},
{
"cell_type": "markdown",
"id": "0de23892-4147-4ca2-a39d-97e4985c8e27",
"id": "d608b85c",
"metadata": {},
"source": [
"As you can see above, synapses follow closely how channels are defined. The main difference is that the `compute_current` method takes two voltages: the pre-synaptic voltage (a `jnp.ndarray` of shape `()`) and the post-synaptic voltage (a `jnp.ndarray` of shape `()`)."
Expand All @@ -317,7 +317,7 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "97ee3c49",
"id": "0652ace1",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -327,7 +327,7 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "a8c9428c-d0dc-4892-966e-11a43aea216d",
"id": "69dde19f",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -341,7 +341,7 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "fc6fd5a2",
"id": "4f19d456",
"metadata": {},
"outputs": [
{
Expand All @@ -364,7 +364,7 @@
{
"cell_type": "code",
"execution_count": 13,
"id": "867f4f38",
"id": "3e955552",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -374,7 +374,7 @@
{
"cell_type": "code",
"execution_count": 14,
"id": "01782d18",
"id": "9cfc8f33",
"metadata": {},
"outputs": [
{
Expand All @@ -398,7 +398,7 @@
},
{
"cell_type": "markdown",
"id": "5b944bee-bf80-4f6b-a5d2-5cd119e6d4a4",
"id": "e3f88051",
"metadata": {},
"source": [
"That's it! You are now ready to build your own custom simulations and equip them with channel and synapse models!\n",
Expand Down
19 changes: 19 additions & 0 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple
from warnings import warn

import jax.numpy as jnp

Expand All @@ -20,6 +21,24 @@ class Channel:
current_name = None

def __init__(self, name: Optional[str] = None):
contact = (
"If you have any questions, please reach out via email to "
"[email protected] or create an issue on Github: "
"https://github.com/jaxleyverse/jaxley/issues. Thank you!"
)
if (
not hasattr(self, "current_is_in_mA_per_cm2")
or not self.current_is_in_mA_per_cm2
):
raise ValueError(
"The channel you are using is deprecated. "
"In Jaxley version 0.5.0, we changed the unit of the current returned "
"by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please "
"update your channel model (by dividing the resulting current by 1000) "
"and set `self.current_is_in_mA_per_cm2=True` as the first line "
f"in the `__init__()` method of your channel. {contact}"
)

self._name = name if name else self.__class__.__name__

@property
Expand Down
9 changes: 5 additions & 4 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class HH(Channel):
"""Hodgkin-Huxley channel."""

def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.channel_params = {
Expand Down Expand Up @@ -52,10 +54,9 @@ def compute_current(
prefix = self._name
m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"]

# Multiply with 1000 to convert Siemens to milli Siemens.
gNa = params[f"{prefix}_gNa"] * (m**3) * h * 1000 # mS/cm^2
gK = params[f"{prefix}_gK"] * n**4 * 1000 # mS/cm^2
gLeak = params[f"{prefix}_gLeak"] * 1000 # mS/cm^2
gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2
gK = params[f"{prefix}_gK"] * n**4 # S/cm^2
gLeak = params[f"{prefix}_gLeak"] # S/cm^2

return (
gNa * (v - params[f"{prefix}_eNa"])
Expand Down
Loading

0 comments on commit 5ce7c45

Please sign in to comment.