Skip to content

Commit

Permalink
update all tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 30, 2023
1 parent ca99b9b commit 069ccc9
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 244 deletions.
52 changes: 23 additions & 29 deletions tutorials/01_small_network.ipynb

Large diffs are not rendered by default.

85 changes: 28 additions & 57 deletions tutorials/02_setting_parameters.ipynb

Large diffs are not rendered by default.

132 changes: 65 additions & 67 deletions tutorials/03_gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"outputs": [],
"source": [
"# I have experienced stability issues with float32.\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"config.update(\"jax_platform_name\", \"cpu\")\n",
"\n",
Expand All @@ -44,8 +44,6 @@
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -128,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"id": "ff784bcb",
"metadata": {},
"outputs": [],
Expand All @@ -138,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"id": "222f9a00",
"metadata": {},
"outputs": [],
Expand All @@ -148,15 +146,15 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"id": "90affb0c-dc77-47c7-be3c-18e2829cc820",
"metadata": {},
"outputs": [],
"source": [
"for cell_ind in range(5):\n",
" network.cell(cell_ind).branch(1).comp(0.0).record()\n",
" \n",
"current = jx.step_current(i_delay, i_dur, i_amp, time_vec)\n",
"current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\n",
"for stim_ind in range(2):\n",
" network.cell(stim_ind).branch(1).comp(0.0).stimulate(current)"
]
Expand All @@ -179,10 +177,18 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"id": "10cb5b1e",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n"
]
}
],
"source": [
"network.make_trainable(\"radius\")"
]
Expand All @@ -197,10 +203,18 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"id": "c90be7f3",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 200. Total number of trainable parameters: 201\n"
]
}
],
"source": [
"network.cell(\"all\").branch(\"all\").comp(\"all\").make_trainable(\"gNa\")"
]
Expand All @@ -223,10 +237,18 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 16,
"id": "f31901bd",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 202\n"
]
}
],
"source": [
"network.make_trainable(\"gS\")"
]
Expand All @@ -241,10 +263,18 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 17,
"id": "12fe7828",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 6. Total number of trainable parameters: 208\n"
]
}
],
"source": [
"network.GlutamateSynapse(\"all\").make_trainable(\"gS\")"
]
Expand All @@ -267,7 +297,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 18,
"id": "40a48eea",
"metadata": {},
"outputs": [],
Expand All @@ -286,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "4eb3f8f1",
"metadata": {},
"outputs": [],
Expand All @@ -312,7 +342,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 20,
"id": "a29f1ac2",
"metadata": {},
"outputs": [],
Expand All @@ -332,7 +362,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 21,
"id": "f38d61a9",
"metadata": {},
"outputs": [],
Expand All @@ -342,7 +372,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 22,
"id": "9ac97e04",
"metadata": {},
"outputs": [],
Expand All @@ -362,22 +392,10 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 23,
"id": "d9ccf1b6",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'optax'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/var/folders/j9/w7ftvg_16t1f9bgp1cy4bt1r0000gn/T/ipykernel_6746/3781452166.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0moptax\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'optax'"
]
}
],
"outputs": [],
"source": [
"import optax"
]
Expand All @@ -392,7 +410,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"id": "710a1545",
"metadata": {},
"outputs": [],
Expand All @@ -413,7 +431,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"id": "800f959e",
"metadata": {},
"outputs": [],
Expand All @@ -437,51 +455,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "9d639efa",
"metadata": {},
"outputs": [],
"source": [
"opt_params = transform.inverse(params)\n",
"optimizer = optax.adam(learning_rate=1e-2)\n",
"optimizer = optax.adam(learning_rate=1e-1)\n",
"opt_state = optimizer.init(opt_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"id": "0e4aebd0-283e-4165-8c24-4b6fb811135e",
"metadata": {},
"outputs": [],
"source": [
"epoch_losses = []\n",
"\n",
"for epoch in range(5):\n",
" loss_val, gradient = jitted_grad(opt_params)\n",
" updates, opt_state = optimizer.update(gradient, opt_state)\n",
" opt_params = optax.apply_updates(opt_params, updates)\n",
"\n",
" print(f\"epoch {epoch}, loss {loss_val}\")\n",
" epoch_losses.append(loss_val)\n",
" \n",
"final_params = transform.forward(opt_params)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "134af3e1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0, loss -64.97740510297487\n",
"epoch 1, loss -64.98296369502924\n",
"epoch 2, loss -64.98846441030534\n",
"epoch 3, loss -64.99390672049375\n",
"epoch 4, loss -64.99929015505666\n"
"epoch 0, loss -64.97740512171988\n",
"epoch 1, loss -65.03050878143252\n",
"epoch 2, loss -65.07820463321355\n",
"epoch 3, loss -65.12091871011332\n",
"epoch 4, loss -65.15909222980945\n"
]
}
],
Expand Down Expand Up @@ -510,9 +508,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "jax",
"display_name": "neurax",
"language": "python",
"name": "jax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -524,7 +522,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.10.11"
}
},
"nbformat": 4,
Expand Down
38 changes: 27 additions & 11 deletions tutorials/04_groups.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"outputs": [],
"source": [
"# I have experienced stability issues with float32.\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"config.update(\"jax_platform_name\", \"cpu\")\n",
"\n",
Expand Down Expand Up @@ -618,7 +618,15 @@
"execution_count": 14,
"id": "3a399a56",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n"
]
}
],
"source": [
"network.fast_spiking.make_trainable(\"gNa\")"
]
Expand All @@ -640,7 +648,7 @@
{
"data": {
"text/plain": [
"[{'gNa': DeviceArray([[0.4]], dtype=float64)}]"
"[{'gNa': Array([[0.4]], dtype=float64)}]"
]
},
"execution_count": 15,
Expand All @@ -665,7 +673,15 @@
"execution_count": 16,
"id": "99a6c389",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of newly added trainable parameters: 3. Total number of trainable parameters: 4\n"
]
}
],
"source": [
"network.cell([0,1,3]).make_trainable(\"axial_resistivity\")"
]
Expand All @@ -679,10 +695,10 @@
{
"data": {
"text/plain": [
"[{'gNa': DeviceArray([[0.4]], dtype=float64)},\n",
" {'axial_resistivity': DeviceArray([[5000.],\n",
" [5000.],\n",
" [5000.]], dtype=float64)}]"
"[{'gNa': Array([[0.4]], dtype=float64)},\n",
" {'axial_resistivity': Array([[5000.],\n",
" [5000.],\n",
" [5000.]], dtype=float64)}]"
]
},
"execution_count": 17,
Expand All @@ -705,9 +721,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "jax",
"display_name": "neurax",
"language": "python",
"name": "jax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -719,7 +735,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.10.11"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 069ccc9

Please sign in to comment.