Skip to content

Commit

Permalink
rename neurax -> jaxley in tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 21, 2023
1 parent f682da0 commit 5ae9134
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 101 deletions.
34 changes: 17 additions & 17 deletions tutorials/01_small_network.ipynb

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions tutorials/02_setting_parameters.ipynb

Large diffs are not rendered by default.

32 changes: 16 additions & 16 deletions tutorials/03_gradient.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# Obtaining the gradient and training (optimizing) the parameters\n",
"\n",
"In this tutorial, we will describe how you can use JAX's automatic differentiation to obtain gradients through `neurax` simulations and how you can use optimize the parameters with the Adam optimizer."
"In this tutorial, we will describe how you can use JAX's automatic differentiation to obtain gradients through `jaxley` simulations and how you can use optimize the parameters with the Adam optimizer."
]
},
{
Expand Down Expand Up @@ -51,9 +51,9 @@
"import jax.numpy as jnp\n",
"from jax import jit, value_and_grad\n",
"\n",
"import neurax as nx\n",
"from neurax.channels import HHChannel\n",
"from neurax.synapses import GlutamateSynapse"
"import jaxley as jx\n",
"from jaxley.channels import HHChannel\n",
"from jaxley.synapses import GlutamateSynapse"
]
},
{
Expand Down Expand Up @@ -109,9 +109,9 @@
"metadata": {},
"outputs": [],
"source": [
"comp = nx.Compartment()\n",
"branch = nx.Branch([comp for _ in range(nseg_per_branch)])\n",
"cell = nx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))"
"comp = jx.Compartment()\n",
"branch = jx.Branch([comp for _ in range(nseg_per_branch)])\n",
"cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))"
]
},
{
Expand All @@ -122,8 +122,8 @@
"outputs": [],
"source": [
"_ = np.random.seed(0)\n",
"conn_builder = nx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])\n",
"connectivities = [nx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]"
"conn_builder = jx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])\n",
"connectivities = [jx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]"
]
},
{
Expand All @@ -133,7 +133,7 @@
"metadata": {},
"outputs": [],
"source": [
"network = nx.Network([cell for _ in range(5)], connectivities)"
"network = jx.Network([cell for _ in range(5)], connectivities)"
]
},
{
Expand All @@ -156,7 +156,7 @@
"for cell_ind in range(5):\n",
" network.cell(cell_ind).branch(1).comp(0.0).record()\n",
" \n",
"current = nx.step_current(i_delay, i_dur, i_amp, time_vec)\n",
"current = jx.step_current(i_delay, i_dur, i_amp, time_vec)\n",
"for stim_ind in range(2):\n",
" network.cell(stim_ind).branch(1).comp(0.0).stimulate(current)"
]
Expand Down Expand Up @@ -281,7 +281,7 @@
"id": "cf68cf64",
"metadata": {},
"source": [
"You can now run the simulation with the trainable parameters by passing them to the `nx.integrate` function."
"You can now run the simulation with the trainable parameters by passing them to the `jx.integrate` function."
]
},
{
Expand All @@ -291,7 +291,7 @@
"metadata": {},
"outputs": [],
"source": [
"s = nx.integrate(network, delta_t=dt, params=params)"
"s = jx.integrate(network, delta_t=dt, params=params)"
]
},
{
Expand All @@ -318,7 +318,7 @@
"outputs": [],
"source": [
"def loss(params):\n",
" s = nx.integrate(network, delta_t=dt, params=params)\n",
" s = jx.integrate(network, delta_t=dt, params=params)\n",
" return jnp.sum(s[0, -1])"
]
},
Expand Down Expand Up @@ -397,7 +397,7 @@
"metadata": {},
"outputs": [],
"source": [
"transform = nx.ParamTransform(\n",
"transform = jx.ParamTransform(\n",
" lowers={\"gNa\": 0.05, \"gK\": 0.01, \"gLeak\": 0.0001, \"radius\": 0.1, \"length\": 1.0, \"axial_resistivity\": 500.0, \"gS\": 0.01}, \n",
" uppers={\"gNa\": 1.1, \"gK\": 0.3, \"gLeak\": 0.001, \"radius\": 5.0, \"length\": 20.0, \"axial_resistivity\": 5500.0, \"gS\": 5.0}, \n",
")"
Expand All @@ -420,7 +420,7 @@
"source": [
"def loss(params):\n",
" params = transform.forward(params)\n",
" s = nx.integrate(network, delta_t=dt, params=params)\n",
" s = jx.integrate(network, delta_t=dt, params=params)\n",
" return jnp.sum(s[0, -1])\n",
"\n",
"jitted_grad = jit(value_and_grad(loss))\n",
Expand Down
18 changes: 9 additions & 9 deletions tutorials/04_groups.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
"import jax.numpy as jnp\n",
"from jax import jit, value_and_grad\n",
"\n",
"import neurax as nx\n",
"from neurax.channels import HHChannel\n",
"from neurax.synapses import GlutamateSynapse"
"import jaxley as jx\n",
"from jaxley.channels import HHChannel\n",
"from jaxley.synapses import GlutamateSynapse"
]
},
{
Expand Down Expand Up @@ -101,9 +101,9 @@
"metadata": {},
"outputs": [],
"source": [
"comp = nx.Compartment()\n",
"branch = nx.Branch([comp for _ in range(nseg_per_branch)])\n",
"cell = nx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))"
"comp = jx.Compartment()\n",
"branch = jx.Branch([comp for _ in range(nseg_per_branch)])\n",
"cell = jx.Cell([branch for _ in range(5)], parents=jnp.asarray([-1, 0, 0, 1, 1]))"
]
},
{
Expand All @@ -114,8 +114,8 @@
"outputs": [],
"source": [
"_ = np.random.seed(0)\n",
"conn_builder = nx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])\n",
"connectivities = [nx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]"
"conn_builder = jx.ConnectivityBuilder([cell.total_nbranches for _ in range(5)])\n",
"connectivities = [jx.Connectivity(GlutamateSynapse(), conn_builder.fc(np.arange(0, 2), np.arange(2, 5)))]"
]
},
{
Expand All @@ -125,7 +125,7 @@
"metadata": {},
"outputs": [],
"source": [
"network = nx.Network([cell for _ in range(5)], connectivities)\n",
"network = jx.Network([cell for _ in range(5)], connectivities)\n",
"network.insert(HHChannel())"
]
},
Expand Down
Loading

0 comments on commit 5ae9134

Please sign in to comment.