diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 0465fdee..406f6657 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -297,6 +297,12 @@ def _make_trainable( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}" ) + def delete_trainables(self): + """Removes all trainable parameters from the module.""" + self.indices_set_by_trainables: List[jnp.ndarray] = [] + self.trainable_params: List[Dict[str, jnp.ndarray]] = [] + self.num_trainable_params: int = 0 + def add_to_group(self, group_name): raise ValueError("`add_to_group()` makes no sense for an entire module.") diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index d997648a..bb22730e 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -37,6 +37,28 @@ def test_make_trainable(): cell.get_parameters() +def test_delete_trainables(): + """Test make_trainable.""" + nseg_per_branch = 8 + + depth = 5 + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + parents = jnp.asarray(parents) + + comp = jx.Compartment().initialize() + branch = jx.Branch(comp, nseg_per_branch).initialize() + cell = jx.Cell(branch, parents=parents).initialize() + + cell.branch(0).comp(0.0).make_trainable("length", 12.0) + assert cell.num_trainable_params == 1 + + cell.delete_trainables() + cell.branch(0).comp(0.0).make_trainable("length", 12.0) + assert cell.num_trainable_params == 1 + + cell.get_parameters() + + def test_make_trainable_network(): """Test make_trainable.""" nseg_per_branch = 8