diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c2b73525..473a875b 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -53,6 +53,7 @@ def __init__(self): self.indices_set_by_trainables: List[jnp.ndarray] = [] self.trainable_params: List[Dict[str, jnp.ndarray]] = [] self.allow_make_trainable: bool = True + self.num_trainable_params: int = 0 # For recordings. self.recordings: pd.DataFrame = pd.DataFrame().from_dict({}) @@ -298,7 +299,12 @@ def _get_states(self, key: str, view): else: raise KeyError("Key not recognized.") - def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None): + def make_trainable( + self, + key: str, + init_val: Optional[Union[float, list]] = None, + verbose: bool = True, + ): """Make a parameter trainable. Args: @@ -308,16 +314,22 @@ def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None to match the number of created parameters. If `None`, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters. + verbose: Whether to print the number of parameters that are added and the + total number of parameters. """ view = deepcopy(self.nodes.assign(controlled_by_param=0)) - self._make_trainable(view, key, init_val) + self._make_trainable(view, key, init_val, verbose=verbose) def _make_trainable( - self, view, key: str, init_val: Optional[Union[float, list]] = None + self, + view, + key: str, + init_val: Optional[Union[float, list]] = None, + verbose: bool = True, ): assert ( self.allow_make_trainable - ), "network.cell('all') is not supported. Use a for-loop over cells." + ), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells." grouped_view = view.groupby("controlled_by_param") inds_of_comps = list(grouped_view.apply(lambda x: x.index.values)) @@ -356,6 +368,11 @@ def _make_trainable( new_params = jnp.mean(param_vals, axis=1, keepdims=True) self.trainable_params.append({key: new_params}) + self.num_trainable_params += num_created_parameters + if verbose: + print( + f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}" + ) def add_to_group(self, group_name): raise ValueError("`add_to_group()` makes no sense for an entire module.") diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 390ff9c5..ec1774a4 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -26,8 +26,10 @@ def test_make_trainable(): cell.branch(0).comp(0.0).set_params("length", 12.0) cell.branch(1).comp(1.0).set_params("gNa", 0.2) + assert cell.num_trainable_params == 2 cell.branch([0, 1]).make_trainable("radius", 1.0) + assert cell.num_trainable_params == 4 cell.branch([0, 1]).make_trainable("length") cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0]) cell.branch([0, 1]).make_trainable("gNa") @@ -69,3 +71,4 @@ def test_make_trainable_network(): cell.get_parameters() net.GlutamateSynapse.set_params("gS", 0.1) + assert cell.num_trainable_params == 8 # `set_params()` is ignored.