diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 473a875b..85cd8e34 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -195,6 +195,7 @@ def _append_to_channel_params_and_state( def _append_to_channel_nodes(self, index, channel): """Adds channel nodes from constituents to `self.channel_nodes`.""" name = type(channel).__name__ + if name in self.channel_nodes: self.channel_nodes[name] = pd.concat( [self.channel_nodes[name], index] @@ -387,7 +388,12 @@ def get_parameters(self): return self.trainable_params def get_all_parameters(self, trainable_params): - """Return all parameters (and coupling conductances) needed to simulate.""" + """Return all parameters (and coupling conductances) needed to simulate. + + This is done by first obtaining the current value of every parameter (not only + the trainable ones) and then replacing the trainable ones with the value + in `trainable_params()`. + """ params = {} for key, val in self.params.items(): params[key] = val diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index f5b85fa3..92bc25a6 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -31,6 +31,9 @@ def __init__( compartment_list = [compartments for _ in range(nseg)] else: compartment_list = compartments + # Compartments are currently defined in reverse. See also #30. This `.reverse` + # is needed to make `tests/test_composability_of_modules.py` pass. + compartment_list.reverse() self._append_to_params_and_state(compartment_list) for comp in compartment_list: diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 01007439..67678264 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -52,6 +52,44 @@ def __init__( ] self.initialize() + + # Indexing. + self.nodes = pd.DataFrame( + dict( + comp_index=np.arange(self.nseg * self.total_nbranches).tolist(), + branch_index=( + np.arange(self.nseg * self.total_nbranches) // self.nseg + ).tolist(), + cell_index=list( + itertools.chain( + *[ + [i] * (self.nseg * b) + for i, b in enumerate(self.nbranches_per_cell) + ] + ) + ), + ) + ) + + # Channel indexing. + for i, cell in enumerate(self.cells): + for channel in cell.channels: + name = type(channel).__name__ + comp_inds = deepcopy(cell.channel_nodes[name]["comp_index"].to_numpy()) + branch_inds = deepcopy( + cell.channel_nodes[name]["branch_index"].to_numpy() + ) + comp_inds += self.nseg * self.cumsum_nbranches[i] + branch_inds += self.cumsum_nbranches[i] + index = pd.DataFrame.from_dict( + dict( + comp_index=comp_inds, + branch_index=branch_inds, + cell_index=[i] * len(comp_inds), + ) + ) + self._append_to_channel_nodes(index, channel) + self.initialized_conds = False def _append_synapses_to_params_and_state(self, connectivities): @@ -106,43 +144,6 @@ def init_morph(self): exclude_first=False, ) - # Indexing. - self.nodes = pd.DataFrame( - dict( - comp_index=np.arange(self.nseg * self.total_nbranches).tolist(), - branch_index=( - np.arange(self.nseg * self.total_nbranches) // self.nseg - ).tolist(), - cell_index=list( - itertools.chain( - *[ - [i] * (self.nseg * b) - for i, b in enumerate(self.nbranches_per_cell) - ] - ) - ), - ) - ) - - # Channel indexing. - for i, cell in enumerate(self.cells): - for channel in cell.channels: - name = type(channel).__name__ - comp_inds = deepcopy(cell.channel_nodes[name]["comp_index"].to_numpy()) - branch_inds = deepcopy( - cell.channel_nodes[name]["branch_index"].to_numpy() - ) - comp_inds += self.nseg * self.cumsum_nbranches[i] - branch_inds += self.cumsum_nbranches[i] - index = pd.DataFrame.from_dict( - dict( - comp_index=comp_inds, - branch_index=branch_inds, - cell_index=[i] * len(comp_inds), - ) - ) - self._append_to_channel_nodes(index, channel) - self.initialized_morph = True def init_conds(self, params): diff --git a/tests/test_composability_of_modules.py b/tests/test_composability_of_modules.py new file mode 100644 index 00000000..36c33b89 --- /dev/null +++ b/tests/test_composability_of_modules.py @@ -0,0 +1,91 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp + +import jaxley as jx +from jaxley.channels import HHChannel + + +def test_compose_branch(): + """Test inserting to comp and composing to branch equals inserting to branch.""" + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp1 = jx.Compartment() + comp1.insert(HHChannel()) + comp2 = jx.Compartment() + branch1 = jx.Branch([comp1, comp2]) + branch1.comp(0.0).record() + branch1.comp(0.0).stimulate(current) + + comp = jx.Compartment() + branch2 = jx.Branch(comp, nseg=2) + branch2.comp(0.0).insert(HHChannel()) + branch2.comp(0.0).record() + branch2.comp(0.0).stimulate(current) + + voltages1 = jx.integrate(branch1, delta_t=dt) + voltages2 = jx.integrate(branch2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8 + + +def test_compose_cell(): + """Test inserting to branch and composing to cell equals inserting to cell.""" + nseg_per_branch = 4 + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp = jx.Compartment() + + branch1 = jx.Branch(comp, nseg_per_branch) + branch1.insert(HHChannel()) + branch2 = jx.Branch(comp, nseg_per_branch) + cell1 = jx.Cell([branch1, branch2], parents=[-1, 0]) + cell1.branch(0).comp(0.0).record() + cell1.branch(0).comp(0.0).stimulate(current) + + branch = jx.Branch(comp, nseg_per_branch) + cell2 = jx.Cell(branch, parents=[-1, 0]) + cell2.branch(0).insert(HHChannel()) + cell2.branch(0).comp(0.0).record() + cell2.branch(0).comp(0.0).stimulate(current) + + voltages1 = jx.integrate(cell1, delta_t=dt) + voltages2 = jx.integrate(cell2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8 + + +def test_compose_net(): + """Test inserting to cell and composing to net equals inserting to net.""" + nseg_per_branch = 4 + dt = 0.025 + t_max = 3.0 + current = jx.step_current(1.0, 1.0, 0.1, dt, t_max) + + comp = jx.Compartment() + branch = jx.Branch(comp, nseg_per_branch) + + cell1 = jx.Cell(branch, parents=[-1, 0, 0]) + cell1.insert(HHChannel()) + cell2 = jx.Cell(branch, parents=[-1, 0, 0]) + net1 = jx.Network([cell1, cell2], []) + net1.cell(0).branch(0).comp(0.0).record() + net1.cell(0).branch(0).comp(0.0).stimulate(current) + + cell = jx.Cell(branch, parents=[-1, 0, 0]) + net2 = jx.Network([cell, cell], []) + net2.cell(0).insert(HHChannel()) + net2.cell(0).branch(0).comp(0.0).record() + net2.cell(0).branch(0).comp(0.0).stimulate(current) + + voltages1 = jx.integrate(net1, delta_t=dt) + voltages2 = jx.integrate(net2, delta_t=dt) + + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8