Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes for composability. #172

Merged
merged 3 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _append_to_channel_params_and_state(
def _append_to_channel_nodes(self, index, channel):
"""Adds channel nodes from constituents to `self.channel_nodes`."""
name = type(channel).__name__

if name in self.channel_nodes:
self.channel_nodes[name] = pd.concat(
[self.channel_nodes[name], index]
Expand Down Expand Up @@ -387,7 +388,12 @@ def get_parameters(self):
return self.trainable_params

def get_all_parameters(self, trainable_params):
"""Return all parameters (and coupling conductances) needed to simulate."""
"""Return all parameters (and coupling conductances) needed to simulate.

This is done by first obtaining the current value of every parameter (not only
the trainable ones) and then replacing the trainable ones with the value
in `trainable_params()`.
"""
params = {}
for key, val in self.params.items():
params[key] = val
Expand Down
3 changes: 3 additions & 0 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(
compartment_list = [compartments for _ in range(nseg)]
else:
compartment_list = compartments
# Compartments are currently defined in reverse. See also #30. This `.reverse`
# is needed to make `tests/test_composability_of_modules.py` pass.
compartment_list.reverse()

self._append_to_params_and_state(compartment_list)
for comp in compartment_list:
Expand Down
75 changes: 38 additions & 37 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ def __init__(
]

self.initialize()

# Indexing.
self.nodes = pd.DataFrame(
dict(
comp_index=np.arange(self.nseg * self.total_nbranches).tolist(),
branch_index=(
np.arange(self.nseg * self.total_nbranches) // self.nseg
).tolist(),
cell_index=list(
itertools.chain(
*[
[i] * (self.nseg * b)
for i, b in enumerate(self.nbranches_per_cell)
]
)
),
)
)

# Channel indexing.
for i, cell in enumerate(self.cells):
for channel in cell.channels:
name = type(channel).__name__
comp_inds = deepcopy(cell.channel_nodes[name]["comp_index"].to_numpy())
branch_inds = deepcopy(
cell.channel_nodes[name]["branch_index"].to_numpy()
)
comp_inds += self.nseg * self.cumsum_nbranches[i]
branch_inds += self.cumsum_nbranches[i]
index = pd.DataFrame.from_dict(
dict(
comp_index=comp_inds,
branch_index=branch_inds,
cell_index=[i] * len(comp_inds),
)
)
self._append_to_channel_nodes(index, channel)

self.initialized_conds = False

def _append_synapses_to_params_and_state(self, connectivities):
Expand Down Expand Up @@ -106,43 +144,6 @@ def init_morph(self):
exclude_first=False,
)

# Indexing.
self.nodes = pd.DataFrame(
dict(
comp_index=np.arange(self.nseg * self.total_nbranches).tolist(),
branch_index=(
np.arange(self.nseg * self.total_nbranches) // self.nseg
).tolist(),
cell_index=list(
itertools.chain(
*[
[i] * (self.nseg * b)
for i, b in enumerate(self.nbranches_per_cell)
]
)
),
)
)

# Channel indexing.
for i, cell in enumerate(self.cells):
for channel in cell.channels:
name = type(channel).__name__
comp_inds = deepcopy(cell.channel_nodes[name]["comp_index"].to_numpy())
branch_inds = deepcopy(
cell.channel_nodes[name]["branch_index"].to_numpy()
)
comp_inds += self.nseg * self.cumsum_nbranches[i]
branch_inds += self.cumsum_nbranches[i]
index = pd.DataFrame.from_dict(
dict(
comp_index=comp_inds,
branch_index=branch_inds,
cell_index=[i] * len(comp_inds),
)
)
self._append_to_channel_nodes(index, channel)

self.initialized_morph = True

def init_conds(self, params):
Expand Down
91 changes: 91 additions & 0 deletions tests/test_composability_of_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import jaxley as jx
from jaxley.channels import HHChannel


def test_compose_branch():
"""Test inserting to comp and composing to branch equals inserting to branch."""
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp1 = jx.Compartment()
comp1.insert(HHChannel())
comp2 = jx.Compartment()
branch1 = jx.Branch([comp1, comp2])
branch1.comp(0.0).record()
branch1.comp(0.0).stimulate(current)

comp = jx.Compartment()
branch2 = jx.Branch(comp, nseg=2)
branch2.comp(0.0).insert(HHChannel())
branch2.comp(0.0).record()
branch2.comp(0.0).stimulate(current)

voltages1 = jx.integrate(branch1, delta_t=dt)
voltages2 = jx.integrate(branch2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8


def test_compose_cell():
"""Test inserting to branch and composing to cell equals inserting to cell."""
nseg_per_branch = 4
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp = jx.Compartment()

branch1 = jx.Branch(comp, nseg_per_branch)
branch1.insert(HHChannel())
branch2 = jx.Branch(comp, nseg_per_branch)
cell1 = jx.Cell([branch1, branch2], parents=[-1, 0])
cell1.branch(0).comp(0.0).record()
cell1.branch(0).comp(0.0).stimulate(current)

branch = jx.Branch(comp, nseg_per_branch)
cell2 = jx.Cell(branch, parents=[-1, 0])
cell2.branch(0).insert(HHChannel())
cell2.branch(0).comp(0.0).record()
cell2.branch(0).comp(0.0).stimulate(current)

voltages1 = jx.integrate(cell1, delta_t=dt)
voltages2 = jx.integrate(cell2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8


def test_compose_net():
"""Test inserting to cell and composing to net equals inserting to net."""
nseg_per_branch = 4
dt = 0.025
t_max = 3.0
current = jx.step_current(1.0, 1.0, 0.1, dt, t_max)

comp = jx.Compartment()
branch = jx.Branch(comp, nseg_per_branch)

cell1 = jx.Cell(branch, parents=[-1, 0, 0])
cell1.insert(HHChannel())
cell2 = jx.Cell(branch, parents=[-1, 0, 0])
net1 = jx.Network([cell1, cell2], [])
net1.cell(0).branch(0).comp(0.0).record()
net1.cell(0).branch(0).comp(0.0).stimulate(current)

cell = jx.Cell(branch, parents=[-1, 0, 0])
net2 = jx.Network([cell, cell], [])
net2.cell(0).insert(HHChannel())
net2.cell(0).branch(0).comp(0.0).record()
net2.cell(0).branch(0).comp(0.0).stimulate(current)

voltages1 = jx.integrate(net1, delta_t=dt)
voltages2 = jx.integrate(net2, delta_t=dt)

assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8