Skip to content

Commit

Permalink
.delete_trainables() method
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 22, 2024
1 parent c5538b8 commit 9152d97
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ def _make_trainable(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}"
)

def delete_trainables(self):
"""Removes all trainable parameters from the module."""
self.indices_set_by_trainables: List[jnp.ndarray] = []
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
self.num_trainable_params: int = 0

def add_to_group(self, group_name):
raise ValueError("`add_to_group()` makes no sense for an entire module.")

Expand Down
22 changes: 22 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ def test_make_trainable():
cell.get_parameters()


def test_delete_trainables():
"""Test make_trainable."""
nseg_per_branch = 8

depth = 5
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)

comp = jx.Compartment().initialize()
branch = jx.Branch(comp, nseg_per_branch).initialize()
cell = jx.Cell(branch, parents=parents).initialize()

cell.branch(0).comp(0.0).make_trainable("length", 12.0)
assert cell.num_trainable_params == 1

cell.delete_trainables()
cell.branch(0).comp(0.0).make_trainable("length", 12.0)
assert cell.num_trainable_params == 1

cell.get_parameters()


def test_make_trainable_network():
"""Test make_trainable."""
nseg_per_branch = 8
Expand Down

0 comments on commit 9152d97

Please sign in to comment.