Skip to content

Commit

Permalink
Adapt all tests to new recommended API
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 14, 2023
1 parent 657a901 commit a8197bc
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions neurax/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Cell(Module):

def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
super().__init__()
assert (
isinstance(branches, Branch) or len(parents) == len(branches) is None
assert isinstance(branches, Branch) or len(parents) == len(
branches
), "If `branches` is a list then you have to provide equally many parents, i.e. len(branches) == len(parents)."
self._init_params_and_state(self.cell_params, self.cell_states)
if isinstance(branches, Branch):
Expand Down
4 changes: 2 additions & 2 deletions tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_similarity():
def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
nseg_per_branch = 8
comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
branch.insert(HHChannel())

radiuses = np.linspace(3.0, 15.0, nseg_per_branch)
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_similarity_complex():
def _neurax_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
nseg = 16
comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg)])
branch = nx.Branch(comp, nseg)

branch.insert(HHChannel())

Expand Down
4 changes: 2 additions & 2 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test_similarity():
def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
nseg_per_branch = 8
comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(3)], parents=[-1, 0, 0]).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
cell = nx.Cell(branch, parents=[-1, 0, 0]).initialize()
cell.insert(HHChannel())

cell.set_params("radius", 5.0)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_api_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def test_api_equivalence():
comp = nx.Compartment().initialize()

branch1 = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell1 = nx.Cell([branch1 for _ in range(num_branches)], parents=parents).initialize()
cell1 = nx.Cell(
[branch1 for _ in range(num_branches)], parents=parents
).initialize()

branch2 = nx.Branch(comp, nseg=nseg_per_branch).initialize()
cell2 = nx.Cell(branch2, parents=parents).initialize()
Expand All @@ -35,4 +37,6 @@ def test_api_equivalence():

voltages1 = nx.integrate(cell1, delta_t=dt)
voltages2 = nx.integrate(cell2, delta_t=dt)
assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8, "Voltages do not match between APIs."
assert (
jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8
), "Voltages do not match between APIs."
6 changes: 3 additions & 3 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _run_long_branch(dt, t_max):
nseg_per_branch = 8

comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg_per_branch)])
branch = nx.Branch(comp, nseg_per_branch)
branch.insert(HHChannel())

branch.comp("all").make_trainable("radius", 1.0)
Expand All @@ -40,8 +40,8 @@ def _run_short_branches(dt, t_max):
parents = jnp.asarray([-1, 0])

comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg_per_branch)])
cell = nx.Cell([branch for _ in range(2)], parents=parents)
branch = nx.Branch(comp, nseg_per_branch)
cell = nx.Cell(branch, parents=parents)
cell.insert(HHChannel())

cell.branch("all").comp("all").make_trainable("radius", 1.0)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ def test_make_trainable():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
cell = nx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

cell.branch(0).comp(0.0).set_params("length", 12.0)
Expand All @@ -42,11 +41,10 @@ def test_make_trainable_network():
depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
cell = nx.Cell(branch, parents=parents).initialize()
cell.insert(HHChannel())

conns = [
Expand Down
10 changes: 4 additions & 6 deletions tests/test_record_and_stimulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ def test_record_and_stimulate_api():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
cell = nx.Cell(branch, parents=parents).initialize()

cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
Expand All @@ -36,11 +35,10 @@ def test_record_shape():
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()
branch = nx.Branch(comp, nseg_per_branch).initialize()
cell = nx.Cell(branch, parents=parents).initialize()

current = nx.step_current(0.0, 1.0, 1.0, 0.025, 3.0)
cell.branch(1).comp(1.0).stimulate(current)
Expand Down

0 comments on commit a8197bc

Please sign in to comment.