Skip to content

Commit

Permalink
return and track number of trainable parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 21, 2023
1 parent 5ae9134 commit d5bdcc5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
25 changes: 21 additions & 4 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self):
self.indices_set_by_trainables: List[jnp.ndarray] = []
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
self.allow_make_trainable: bool = True
self.num_trainable_params: int = 0

# For recordings.
self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})
Expand Down Expand Up @@ -298,7 +299,12 @@ def _get_states(self, key: str, view):
else:
raise KeyError("Key not recognized.")

def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None):
def make_trainable(
self,
key: str,
init_val: Optional[Union[float, list]] = None,
verbose: bool = True,
):
"""Make a parameter trainable.
Args:
Expand All @@ -308,16 +314,22 @@ def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None
to match the number of created parameters. If `None`, the current
parameter value is used and, if parameter sharing is performed, the
current parameter value is averaged over all shared parameters.
verbose: Whether to print the number of parameters that are added and the
total number of parameters.
"""
view = deepcopy(self.nodes.assign(controlled_by_param=0))
self._make_trainable(view, key, init_val)
self._make_trainable(view, key, init_val, verbose=verbose)

def _make_trainable(
self, view, key: str, init_val: Optional[Union[float, list]] = None
self,
view,
key: str,
init_val: Optional[Union[float, list]] = None,
verbose: bool = True,
):
assert (
self.allow_make_trainable
), "network.cell('all') is not supported. Use a for-loop over cells."
), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells."

grouped_view = view.groupby("controlled_by_param")
inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
Expand Down Expand Up @@ -356,6 +368,11 @@ def _make_trainable(
new_params = jnp.mean(param_vals, axis=1, keepdims=True)

self.trainable_params.append({key: new_params})
self.num_trainable_params += num_created_parameters
if verbose:
print(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}"
)

def add_to_group(self, group_name):
raise ValueError("`add_to_group()` makes no sense for an entire module.")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def test_make_trainable():

cell.branch(0).comp(0.0).set_params("length", 12.0)
cell.branch(1).comp(1.0).set_params("gNa", 0.2)
assert cell.num_trainable_params == 2

cell.branch([0, 1]).make_trainable("radius", 1.0)
assert cell.num_trainable_params == 4
cell.branch([0, 1]).make_trainable("length")
cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0])
cell.branch([0, 1]).make_trainable("gNa")
Expand Down Expand Up @@ -69,3 +71,4 @@ def test_make_trainable_network():

cell.get_parameters()
net.GlutamateSynapse.set_params("gS", 0.1)
assert cell.num_trainable_params == 8 # `set_params()` is ignored.

0 comments on commit d5bdcc5

Please sign in to comment.