Skip to content

Commit

Permalink
bugfix for synapse indexing if a list is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 11, 2023
1 parent 721a05a commit b06df16
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 3 deletions.
8 changes: 7 additions & 1 deletion jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,14 @@ def show(

def adjust_view(self, key: str, index: float):
"""Update view."""
if index != "all":
if isinstance(index, int) or isinstance(index, np.int64):
self.view = self.view[self.view[key] == index]
elif isinstance(index, list):
self.view = self.view[self.view[key].isin(index)]
else:
assert index == "all"
self.view["controlled_by_param"] -= self.view["controlled_by_param"].iloc[0]

return self

def set_params(self, key: str, val: float):
Expand Down
69 changes: 67 additions & 2 deletions tests/test_synapse_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,79 @@
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp
import numpy as np

import jaxley as jx
from jaxley.synapses import GlutamateSynapse, TestSynapse


def test_set_params_and_querying_params():
def test_set_params_and_querying_params_one_type():
"""Test if the correct parameters are set if one type of synapses is inserted."""
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1])
net = jx.Network([cell for _ in range(4)])

for pre_ind in [0, 1]:
for post_ind in [2, 3]:
pre = net.cell(pre_ind).branch(0).comp(0.0)
post = net.cell(post_ind).branch(0).comp(0.0)
pre.connect(post, GlutamateSynapse())

net.set_params("gS", 0.15)
assert np.all(net.syn_params["gS"] == 0.15)

net.GlutamateSynapse.set_params("gS", 0.32)
assert np.all(net.syn_params["gS"] == 0.32)

net.GlutamateSynapse(1).set_params("gS", 0.18)
assert net.syn_params["gS"][1] == 0.18
assert np.all(net.syn_params["gS"][np.asarray([0, 2, 3])] == 0.32)

net.GlutamateSynapse([2, 3]).set_params("gS", 0.12)
assert net.syn_params["gS"][0] == 0.32
assert net.syn_params["gS"][1] == 0.18
assert np.all(net.syn_params["gS"][np.asarray([2, 3])] == 0.12)


def test_set_params_and_querying_params_two_types():
"""Test whether the correct parameters are set."""
pass
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1])
net = jx.Network([cell for _ in range(4)])

for pre_ind in [0, 1]:
for post_ind, synapse in zip([2, 3], [GlutamateSynapse(), TestSynapse()]):
pre = net.cell(pre_ind).branch(0).comp(0.0)
post = net.cell(post_ind).branch(0).comp(0.0)
pre.connect(post, synapse)

net.set_params("gS", 0.15)
assert np.all(net.syn_params["gS"] == 0.15)
assert np.all(net.syn_params["gC"] == 0.5) # 0.5 is the default value.

net.GlutamateSynapse.set_params("gS", 0.32)
assert np.all(net.syn_params["gS"] == 0.32)
assert np.all(net.syn_params["gC"] == 0.5) # 0.5 is the default value.

net.TestSynapse.set_params("gC", 0.18)
assert np.all(net.syn_params["gS"] == 0.32)
assert np.all(net.syn_params["gC"] == 0.18)

net.GlutamateSynapse(1).set_params("gS", 0.24)
assert net.syn_params["gS"][0] == 0.32
assert net.syn_params["gS"][1] == 0.24
assert np.all(net.syn_params["gC"] == 0.18)

net.GlutamateSynapse([0, 1]).set_params("gS", 0.27)
assert np.all(net.syn_params["gS"] == 0.27)
assert np.all(net.syn_params["gC"] == 0.18)

net.TestSynapse([0, 1]).set_params("gC", 0.21)
assert np.all(net.syn_params["gS"] == 0.27)
assert np.all(net.syn_params["gC"] == 0.21)


def test_shuffling_order_of_set_params():
"""Test whether the result is the same if the order of `set_params` is changed."""
Expand Down

0 comments on commit b06df16

Please sign in to comment.