From 2f2ccd1a80c9e0545d3325078f2327ffc78f5b1d Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 11:18:05 +0100 Subject: [PATCH] fix/rm: rm test for laxy indexing into groups, and rebase onto main. --- jaxley/modules/base.py | 9 +++++++-- jaxley/modules/network.py | 2 +- tests/test_groups.py | 31 ------------------------------- 3 files changed, 8 insertions(+), 34 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d034923e..4ad8e7a4 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -43,12 +43,17 @@ def only_allow_module(func): - """Decorator to only allow the function to be called on Module instances.""" + """Decorator to only allow the function to be called on Module instances. + + Decorates methods of Module that cannot be called on Views of Modules instances. + and have to be called on the Module itself.""" def wrapper(self, *args, **kwargs): + module_name = self.base.__class__.__name__ + method_name = func.__name__ assert not isinstance( self, View - ), "This function can only be called on Module instances" + ), f"{method_name} is currently not supported for Views. Call on the {module_name} base Module." return func(self, *args, **kwargs) return wrapper diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 7efa2c29..c225545e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -105,7 +105,7 @@ def __init__( # Channels. self._gather_channels_from_constituents(cells) - self.initialize() + self._initialize() del self._cells_list def __repr__(self): diff --git a/tests/test_groups.py b/tests/test_groups.py index 2987fb78..00e22ee5 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -101,37 +101,6 @@ def test_subclassing_groups_net_make_trainable_equivalence(): assert jnp.array_equal(inds1, inds2) -def test_subclassing_groups_net_lazy_indexing_make_trainable_equivalence(): - """Test whether groups can be indexing in a lazy way.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) - - net1.cell([0, 3, 5]).add_to_group("excitatory") - net2.cell([0, 3, 5]).add_to_group("excitatory") - - # The following lines are made possible by PR #324. - net1.excitatory.cell([0, 3]).branch(0).make_trainable("radius") - net1.excitatory.cell([0, 5]).branch(1).comp("all").make_trainable("length") - net1.excitatory.cell("all").branch(1).comp(2).make_trainable("axial_resistivity") - params1 = jnp.concatenate(jax.tree_util.tree_flatten(net1.get_parameters())[0]) - - # The following lines are made possible by PR #324. - net2.excitatory[[0, 3], 0].make_trainable("radius") - net2.excitatory[[0, 5], 1, :].make_trainable("length") - net2.excitatory[:, 1, 2].make_trainable("axial_resistivity") - params2 = jnp.concatenate(jax.tree_util.tree_flatten(net2.get_parameters())[0]) - - assert jnp.array_equal(params1, params2) - - for inds1, inds2 in zip( - net1.indices_set_by_trainables, net2.indices_set_by_trainables - ): - assert jnp.array_equal(inds1, inds2) - - def test_fully_connect_groups_equivalence(): """Test whether groups can be used with `fully_connect`.""" comp = jx.Compartment()