Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New recommended API #173

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -14,14 +14,29 @@ class Branch(Module):
branch_params: Dict = {}
branch_states: Dict = {}

def __init__(self, compartments: List[Compartment]):
def __init__(
self,
compartments: Union[Compartment, List[Compartment]],
nseg: Optional[int] = None,
):
super().__init__()
assert (
isinstance(compartments, Compartment) or nseg is None
), "If `compartments` is a list then you cannot set `nseg`."
assert (
isinstance(compartments, List) or nseg is not None
), "If `compartments` is not a list then you have to set `nseg`."
self._init_params_and_state(self.branch_params, self.branch_states)
self._append_to_params_and_state(compartments)
for comp in compartments:
if isinstance(compartments, Compartment):
compartment_list = [compartments for _ in range(nseg)]
else:
compartment_list = compartments

self._append_to_params_and_state(compartment_list)
for comp in compartment_list:
self._append_to_channel_params_and_state(comp)

self.nseg = len(compartments)
self.nseg = len(compartment_list)
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand All @@ -36,7 +51,7 @@ def __init__(self, compartments: List[Compartment]):
)

# Channel indexing.
for i, comp in enumerate(compartments):
for i, comp in enumerate(compartment_list):
index = pd.DataFrame.from_dict(
dict(comp_index=[i], branch_index=[0], cell_index=[0])
)
Expand Down
26 changes: 17 additions & 9 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -23,18 +23,26 @@ class Cell(Module):
cell_params: Dict = {}
cell_states: Dict = {}

def __init__(self, branches: List[Branch], parents: List):
def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
super().__init__()
assert isinstance(branches, Branch) or len(parents) == len(
branches
), "If `branches` is a list then you have to provide equally many parents, i.e. len(branches) == len(parents)."
self._init_params_and_state(self.cell_params, self.cell_states)
self._append_to_params_and_state(branches)
for branch in branches:
if isinstance(branches, Branch):
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches

self._append_to_params_and_state(branch_list)
for branch in branch_list:
self._append_to_channel_params_and_state(branch)

self.nseg = branches[0].nseg
self.total_nbranches = len(branches)
self.nbranches_per_cell = [len(branches)]
self.nseg = branch_list[0].nseg
self.total_nbranches = len(branch_list)
self.nbranches_per_cell = [len(branch_list)]
self.comb_parents = jnp.asarray(parents)
self.cumsum_nbranches = jnp.asarray([0, len(branches)])
self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])

# Indexing.
self.nodes = pd.DataFrame(
Expand All @@ -48,7 +56,7 @@ def __init__(self, branches: List[Branch], parents: List):
)

# Channel indexing.
for i, branch in enumerate(branches):
for i, branch in enumerate(branch_list):
for channel in branch.channels:
name = type(channel).__name__
comp_inds = deepcopy(
Expand Down
2 changes: 1 addition & 1 deletion tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_similarity_complex():
def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
nseg = 16
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg)])
branch = jx.Branch(comp, nseg)

branch.insert(HHChannel())

Expand Down
4 changes: 2 additions & 2 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_similarity():
def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max):
nseg_per_branch = 8
comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(3)], parents=[-1, 0, 0]).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=[-1, 0, 0]).initialize()
cell.insert(HHChannel())

cell.set_params("radius", 5.0)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_api_equivalence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import jaxley as jx


def test_api_equivalence():
"""Test the API for recording and stimulating."""
nseg_per_branch = 2
depth = 2
dt = 0.025

parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()

branch1 = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell1 = jx.Cell(
[branch1 for _ in range(num_branches)], parents=parents
).initialize()

branch2 = jx.Branch(comp, nseg=nseg_per_branch).initialize()
cell2 = jx.Cell(branch2, parents=parents).initialize()

cell1.branch(2).comp(0.4).record()
cell2.branch(2).comp(0.4).record()

current = jx.step_current(0.5, 1.0, 1.0, dt, 3.0)
cell1.branch(1).comp(1.0).stimulate(current)
cell2.branch(1).comp(1.0).stimulate(current)

voltages1 = jx.integrate(cell1, delta_t=dt)
voltages2 = jx.integrate(cell2, delta_t=dt)
assert (
jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8
), "Voltages do not match between APIs."
6 changes: 3 additions & 3 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _run_long_branch(dt, t_max):
nseg_per_branch = 8

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
branch = jx.Branch(comp, nseg_per_branch)
branch.insert(HHChannel())

branch.comp("all").make_trainable("radius", 1.0)
Expand All @@ -40,8 +40,8 @@ def _run_short_branches(dt, t_max):
parents = jnp.asarray([-1, 0])

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg_per_branch)])
cell = jx.Cell([branch for _ in range(2)], parents=parents)
branch = jx.Branch(comp, nseg_per_branch)
cell = jx.Cell(branch, parents=parents)
cell.insert(HHChannel())

cell.branch("all").comp("all").make_trainable("radius", 1.0)
Expand Down
14 changes: 6 additions & 8 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ def test_make_trainable():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

cell.branch(0).comp(0.0).set_params("length", 12.0)
cell.branch(1).comp(1.0).set_params("gNa", 0.2)
assert cell.num_trainable_params == 2
assert cell.num_trainable_params == 0

cell.branch([0, 1]).make_trainable("radius", 1.0)
assert cell.num_trainable_params == 4
assert cell.num_trainable_params == 2
cell.branch([0, 1]).make_trainable("length")
cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0])
cell.branch([0, 1]).make_trainable("gNa")
Expand All @@ -44,11 +43,10 @@ def test_make_trainable_network():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

conns = [
Expand Down
10 changes: 4 additions & 6 deletions tests/test_record_and_stimulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ def test_record_and_stimulate_api():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()

cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
Expand All @@ -36,11 +35,10 @@ def test_record_shape():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()

current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0)
cell.branch(1).comp(1.0).stimulate(current)
Expand Down