Skip to content

Commit

Permalink
fix/rm: rm test for laxy indexing into groups, and rebase onto main.
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Oct 29, 2024
1 parent 192a5d1 commit 2f2ccd1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 34 deletions.
9 changes: 7 additions & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,17 @@


def only_allow_module(func):
"""Decorator to only allow the function to be called on Module instances."""
"""Decorator to only allow the function to be called on Module instances.
Decorates methods of Module that cannot be called on Views of Modules instances.
and have to be called on the Module itself."""

def wrapper(self, *args, **kwargs):
module_name = self.base.__class__.__name__
method_name = func.__name__
assert not isinstance(
self, View
), "This function can only be called on Module instances"
), f"{method_name} is currently not supported for Views. Call on the {module_name} base Module."
return func(self, *args, **kwargs)

return wrapper
Expand Down
2 changes: 1 addition & 1 deletion jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
# Channels.
self._gather_channels_from_constituents(cells)

self.initialize()
self._initialize()
del self._cells_list

def __repr__(self):
Expand Down
31 changes: 0 additions & 31 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,37 +101,6 @@ def test_subclassing_groups_net_make_trainable_equivalence():
assert jnp.array_equal(inds1, inds2)


def test_subclassing_groups_net_lazy_indexing_make_trainable_equivalence():
"""Test whether groups can be indexing in a lazy way."""
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell = jx.Cell(branch, [-1, 0])
net1 = jx.Network([cell for _ in range(10)])
net2 = jx.Network([cell for _ in range(10)])

net1.cell([0, 3, 5]).add_to_group("excitatory")
net2.cell([0, 3, 5]).add_to_group("excitatory")

# The following lines are made possible by PR #324.
net1.excitatory.cell([0, 3]).branch(0).make_trainable("radius")
net1.excitatory.cell([0, 5]).branch(1).comp("all").make_trainable("length")
net1.excitatory.cell("all").branch(1).comp(2).make_trainable("axial_resistivity")
params1 = jnp.concatenate(jax.tree_util.tree_flatten(net1.get_parameters())[0])

# The following lines are made possible by PR #324.
net2.excitatory[[0, 3], 0].make_trainable("radius")
net2.excitatory[[0, 5], 1, :].make_trainable("length")
net2.excitatory[:, 1, 2].make_trainable("axial_resistivity")
params2 = jnp.concatenate(jax.tree_util.tree_flatten(net2.get_parameters())[0])

assert jnp.array_equal(params1, params2)

for inds1, inds2 in zip(
net1.indices_set_by_trainables, net2.indices_set_by_trainables
):
assert jnp.array_equal(inds1, inds2)


def test_fully_connect_groups_equivalence():
"""Test whether groups can be used with `fully_connect`."""
comp = jx.Compartment()
Expand Down

0 comments on commit 2f2ccd1

Please sign in to comment.