From 14e6f7dd14e639052a6c4ecbced2cd01e2a954b4 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 22 Nov 2023 09:29:41 +0100 Subject: [PATCH] New recommended API --- jaxley/modules/branch.py | 27 +++++++++++++---- jaxley/modules/cell.py | 26 +++++++++++------ tests/neurax_vs_neuron/test_branch.py | 2 +- tests/neurax_vs_neuron/test_cell.py | 4 +-- tests/test_api_equivalence.py | 42 +++++++++++++++++++++++++++ tests/test_cell_matches_branch.py | 6 ++-- tests/test_make_trainable.py | 14 ++++----- tests/test_record_and_stimulate.py | 10 +++---- 8 files changed, 96 insertions(+), 35 deletions(-) create mode 100644 tests/test_api_equivalence.py diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index baa0a4bf..f5b85fa3 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union import jax.numpy as jnp import numpy as np @@ -14,14 +14,29 @@ class Branch(Module): branch_params: Dict = {} branch_states: Dict = {} - def __init__(self, compartments: List[Compartment]): + def __init__( + self, + compartments: Union[Compartment, List[Compartment]], + nseg: Optional[int] = None, + ): super().__init__() + assert ( + isinstance(compartments, Compartment) or nseg is None + ), "If `compartments` is a list then you cannot set `nseg`." + assert ( + isinstance(compartments, List) or nseg is not None + ), "If `compartments` is not a list then you have to set `nseg`." self._init_params_and_state(self.branch_params, self.branch_states) - self._append_to_params_and_state(compartments) - for comp in compartments: + if isinstance(compartments, Compartment): + compartment_list = [compartments for _ in range(nseg)] + else: + compartment_list = compartments + + self._append_to_params_and_state(compartment_list) + for comp in compartment_list: self._append_to_channel_params_and_state(comp) - self.nseg = len(compartments) + self.nseg = len(compartment_list) self.total_nbranches = 1 self.nbranches_per_cell = [1] self.cumsum_nbranches = jnp.asarray([0, 1]) @@ -36,7 +51,7 @@ def __init__(self, compartments: List[Compartment]): ) # Channel indexing. - for i, comp in enumerate(compartments): + for i, comp in enumerate(compartment_list): index = pd.DataFrame.from_dict( dict(comp_index=[i], branch_index=[0], cell_index=[0]) ) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index c171e11e..c56a8e66 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Union import jax.numpy as jnp import numpy as np @@ -23,18 +23,26 @@ class Cell(Module): cell_params: Dict = {} cell_states: Dict = {} - def __init__(self, branches: List[Branch], parents: List): + def __init__(self, branches: Union[Branch, List[Branch]], parents: List): super().__init__() + assert isinstance(branches, Branch) or len(parents) == len( + branches + ), "If `branches` is a list then you have to provide equally many parents, i.e. len(branches) == len(parents)." self._init_params_and_state(self.cell_params, self.cell_states) - self._append_to_params_and_state(branches) - for branch in branches: + if isinstance(branches, Branch): + branch_list = [branches for _ in range(len(parents))] + else: + branch_list = branches + + self._append_to_params_and_state(branch_list) + for branch in branch_list: self._append_to_channel_params_and_state(branch) - self.nseg = branches[0].nseg - self.total_nbranches = len(branches) - self.nbranches_per_cell = [len(branches)] + self.nseg = branch_list[0].nseg + self.total_nbranches = len(branch_list) + self.nbranches_per_cell = [len(branch_list)] self.comb_parents = jnp.asarray(parents) - self.cumsum_nbranches = jnp.asarray([0, len(branches)]) + self.cumsum_nbranches = jnp.asarray([0, len(branch_list)]) # Indexing. self.nodes = pd.DataFrame( @@ -48,7 +56,7 @@ def __init__(self, branches: List[Branch], parents: List): ) # Channel indexing. - for i, branch in enumerate(branches): + for i, branch in enumerate(branch_list): for channel in branch.channels: name = type(channel).__name__ comp_inds = deepcopy( diff --git a/tests/neurax_vs_neuron/test_branch.py b/tests/neurax_vs_neuron/test_branch.py index 4e20c3b0..45e65a8c 100644 --- a/tests/neurax_vs_neuron/test_branch.py +++ b/tests/neurax_vs_neuron/test_branch.py @@ -159,7 +159,7 @@ def test_similarity_complex(): def _jaxley_complex(i_delay, i_dur, i_amp, dt, t_max, diams): nseg = 16 comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg)]) + branch = jx.Branch(comp, nseg) branch.insert(HHChannel()) diff --git a/tests/neurax_vs_neuron/test_cell.py b/tests/neurax_vs_neuron/test_cell.py index c0dc8fc3..a04add9b 100644 --- a/tests/neurax_vs_neuron/test_cell.py +++ b/tests/neurax_vs_neuron/test_cell.py @@ -37,8 +37,8 @@ def test_similarity(): def _run_jaxley(i_delay, i_dur, i_amp, dt, t_max): nseg_per_branch = 8 comp = jx.Compartment().initialize() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() - cell = jx.Cell([branch for _ in range(3)], parents=[-1, 0, 0]).initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=[-1, 0, 0]).initialize() cell.insert(HHChannel()) cell.set_params("radius", 5.0) diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py new file mode 100644 index 00000000..690dfc4c --- /dev/null +++ b/tests/test_api_equivalence.py @@ -0,0 +1,42 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp + +import jaxley as jx + + +def test_api_equivalence(): + """Test the API for recording and stimulating.""" + nseg_per_branch = 2 + depth = 2 + dt = 0.025 + + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + parents = jnp.asarray(parents) + num_branches = len(parents) + + comp = jx.Compartment().initialize() + + branch1 = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() + cell1 = jx.Cell( + [branch1 for _ in range(num_branches)], parents=parents + ).initialize() + + branch2 = jx.Branch(comp, nseg=nseg_per_branch).initialize() + cell2 = jx.Cell(branch2, parents=parents).initialize() + + cell1.branch(2).comp(0.4).record() + cell2.branch(2).comp(0.4).record() + + current = jx.step_current(0.5, 1.0, 1.0, dt, 3.0) + cell1.branch(1).comp(1.0).stimulate(current) + cell2.branch(1).comp(1.0).stimulate(current) + + voltages1 = jx.integrate(cell1, delta_t=dt) + voltages2 = jx.integrate(cell2, delta_t=dt) + assert ( + jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8 + ), "Voltages do not match between APIs." diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index 5eb77edb..1bf0fff0 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -16,7 +16,7 @@ def _run_long_branch(dt, t_max): nseg_per_branch = 8 comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) + branch = jx.Branch(comp, nseg_per_branch) branch.insert(HHChannel()) branch.comp("all").make_trainable("radius", 1.0) @@ -40,8 +40,8 @@ def _run_short_branches(dt, t_max): parents = jnp.asarray([-1, 0]) comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]) - cell = jx.Cell([branch for _ in range(2)], parents=parents) + branch = jx.Branch(comp, nseg_per_branch) + cell = jx.Cell(branch, parents=parents) cell.insert(HHChannel()) cell.branch("all").comp("all").make_trainable("radius", 1.0) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index ec1774a4..0abe789e 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -17,19 +17,18 @@ def test_make_trainable(): depth = 5 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] parents = jnp.asarray(parents) - num_branches = len(parents) comp = jx.Compartment().initialize() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() - cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=parents).initialize() cell.insert(HHChannel()) cell.branch(0).comp(0.0).set_params("length", 12.0) cell.branch(1).comp(1.0).set_params("gNa", 0.2) - assert cell.num_trainable_params == 2 + assert cell.num_trainable_params == 0 cell.branch([0, 1]).make_trainable("radius", 1.0) - assert cell.num_trainable_params == 4 + assert cell.num_trainable_params == 2 cell.branch([0, 1]).make_trainable("length") cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0]) cell.branch([0, 1]).make_trainable("gNa") @@ -44,11 +43,10 @@ def test_make_trainable_network(): depth = 5 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] parents = jnp.asarray(parents) - num_branches = len(parents) comp = jx.Compartment().initialize() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() - cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=parents).initialize() cell.insert(HHChannel()) conns = [ diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 8865b1f1..897930b1 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -14,11 +14,10 @@ def test_record_and_stimulate_api(): depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] parents = jnp.asarray(parents) - num_branches = len(parents) comp = jx.Compartment().initialize() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() - cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=parents).initialize() cell.branch(0).comp(0.0).record() cell.branch(1).comp(1.0).record() @@ -36,11 +35,10 @@ def test_record_shape(): depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] parents = jnp.asarray(parents) - num_branches = len(parents) comp = jx.Compartment().initialize() - branch = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize() - cell = jx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=parents).initialize() current = jx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) cell.branch(1).comp(1.0).stimulate(current)