Skip to content

Commit

Permalink
New recommended API
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 22, 2023
1 parent 4455f1f commit 14e6f7d
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 35 deletions.
27 changes: 21 additions & 6 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -14,14 +14,29 @@ class Branch(Module):
branch_params: Dict = {}
branch_states: Dict = {}

def __init__(self, compartments: List[Compartment]):
def __init__(
self,
compartments: Union[Compartment, List[Compartment]],
nseg: Optional[int] = None,
):
super().__init__()
assert (
isinstance(compartments, Compartment) or nseg is None
), "If `compartments` is a list then you cannot set `nseg`."
assert (
isinstance(compartments, List) or nseg is not None
), "If `compartments` is not a list then you have to set `nseg`."
self._init_params_and_state(self.branch_params, self.branch_states)
self._append_to_params_and_state(compartments)
for comp in compartments:
if isinstance(compartments, Compartment):
compartment_list = [compartments for _ in range(nseg)]
else:
compartment_list = compartments

self._append_to_params_and_state(compartment_list)
for comp in compartment_list:
self._append_to_channel_params_and_state(comp)

self.nseg = len(compartments)
self.nseg = len(compartment_list)
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand All @@ -36,7 +51,7 @@ def __init__(self, compartments: List[Compartment]):
)

# Channel indexing.
for i, comp in enumerate(compartments):
for i, comp in enumerate(compartment_list):
index = pd.DataFrame.from_dict(
dict(comp_index=[i], branch_index=[0], cell_index=[0])
)
Expand Down
26 changes: 17 additions & 9 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -23,18 +23,26 @@ class Cell(Module):
cell_params: Dict = {}
cell_states: Dict = {}

def __init__(self, branches: List[Branch], parents: List):
def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
super().__init__()
assert isinstance(branches, Branch) or len(parents) == len(
branches
), "If `branches` is a list then you have to provide equally many parents, i.e. len(branches) == len(parents)."
self._init_params_and_state(self.cell_params, self.cell_states)
self._append_to_params_and_state(branches)
for branch in branches:
if isinstance(branches, Branch):
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches

self._append_to_params_and_state(branch_list)
for branch in branch_list:
self._append_to_channel_params_and_state(branch)

self.nseg = branches[0].nseg
self.total_nbranches = len(branches)
self.nbranches_per_cell = [len(branches)]
self.nseg = branch_list[0].nseg
self.total_nbranches = len(branch_list)
self.nbranches_per_cell = [len(branch_list)]
self.comb_parents = jnp.asarray(parents)
self.cumsum_nbranches = jnp.asarray([0, len(branches)])
self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])

# Indexing.
self.nodes = pd.DataFrame(
Expand All @@ -48,7 +56,7 @@ def __init__(self, branches: List[Branch], parents: List):
)

# Channel indexing.
for i, branch in enumerate(branches):
for i, branch in enumerate(branch_list):
for channel in branch.channels:
name = type(channel).__name__
comp_inds = deepcopy(
Expand Down
2 changes: 1 addition & 1 deletion tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_similarity_complex():
def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
nseg = 16
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg)])
branch = jx.Branch(comp, nseg)

branch.insert(HHChannel())

Expand Down
4 changes: 2 additions & 2 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_similarity():
def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max):
nseg_per_branch = 8
comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(3)], parents=[-1, 0, 0]).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=[-1, 0, 0]).initialize()
cell.insert(HHChannel())

cell.set_params("radius", 5.0)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_api_equivalence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import jaxley as jx


def test_api_equivalence():
"""Test the API for recording and stimulating."""
nseg_per_branch = 2
depth = 2
dt = 0.025

parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()

branch1 = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell1 = jx.Cell(
[branch1 for _ in range(num_branches)], parents=parents
).initialize()

branch2 = jx.Branch(comp, nseg=nseg_per_branch).initialize()
cell2 = jx.Cell(branch2, parents=parents).initialize()

cell1.branch(2).comp(0.4).record()
cell2.branch(2).comp(0.4).record()

current = jx.step_current(0.5, 1.0, 1.0, dt, 3.0)
cell1.branch(1).comp(1.0).stimulate(current)
cell2.branch(1).comp(1.0).stimulate(current)

voltages1 = jx.integrate(cell1, delta_t=dt)
voltages2 = jx.integrate(cell2, delta_t=dt)
assert (
jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8
), "Voltages do not match between APIs."
6 changes: 3 additions & 3 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _run_long_branch(dt, t_max):
nseg_per_branch = 8

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
branch = jx.Branch(comp, nseg_per_branch)
branch.insert(HHChannel())

branch.comp("all").make_trainable("radius", 1.0)
Expand All @@ -40,8 +40,8 @@ def _run_short_branches(dt, t_max):
parents = jnp.asarray([-1, 0])

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(2)], parents=parents)
branch = jx.Branch(comp, nseg_per_branch)
cell = jx.Cell(branch, parents=parents)
cell.insert(HHChannel())

cell.branch("all").comp("all").make_trainable("radius", 1.0)
Expand Down
14 changes: 6 additions & 8 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ def test_make_trainable():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

cell.branch(0).comp(0.0).set_params("length", 12.0)
cell.branch(1).comp(1.0).set_params("gNa", 0.2)
assert cell.num_trainable_params == 2
assert cell.num_trainable_params == 0

cell.branch([0, 1]).make_trainable("radius", 1.0)
assert cell.num_trainable_params == 4
assert cell.num_trainable_params == 2
cell.branch([0, 1]).make_trainable("length")
cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0])
cell.branch([0, 1]).make_trainable("gNa")
Expand All @@ -44,11 +43,10 @@ def test_make_trainable_network():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

conns = [
Expand Down
10 changes: 4 additions & 6 deletions tests/test_record_and_stimulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ def test_record_and_stimulate_api():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()

cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
Expand All @@ -36,11 +35,10 @@ def test_record_shape():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()

current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0)
cell.branch(1).comp(1.0).stimulate(current)
Expand Down

0 comments on commit 14e6f7d

Please sign in to comment.