Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return and track number of trainable parameters #170

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self):
self.indices_set_by_trainables: List[jnp.ndarray] = []
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
self.allow_make_trainable: bool = True
self.num_trainable_params: int = 0

# For recordings.
self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})
Expand Down Expand Up @@ -298,7 +299,12 @@ def _get_states(self, key: str, view):
else:
raise KeyError("Key not recognized.")

def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None):
def make_trainable(
self,
key: str,
init_val: Optional[Union[float, list]] = None,
verbose: bool = True,
):
"""Make a parameter trainable.

Args:
Expand All @@ -308,16 +314,22 @@ def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None
to match the number of created parameters. If `None`, the current
parameter value is used and if parameter sharing is performed that the
current parameter value is averaged over all shared parameters.
verbose: Whether to print the number of parameters that are added and the
total number of parameters.
"""
view = deepcopy(self.nodes.assign(controlled_by_param=0))
self._make_trainable(view, key, init_val)
self._make_trainable(view, key, init_val, verbose=verbose)

def _make_trainable(
self, view, key: str, init_val: Optional[Union[float, list]] = None
self,
view,
key: str,
init_val: Optional[Union[float, list]] = None,
verbose: bool = True,
):
assert (
self.allow_make_trainable
), "network.cell('all') is not supported. Use a for-loop over cells."
), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells."

grouped_view = view.groupby("controlled_by_param")
inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
Expand Down Expand Up @@ -356,6 +368,11 @@ def _make_trainable(
new_params = jnp.mean(param_vals, axis=1, keepdims=True)

self.trainable_params.append({key: new_params})
self.num_trainable_params += num_created_parameters
if verbose:
print(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}"
)

def add_to_group(self, group_name):
raise ValueError("`add_to_group()` makes no sense for an entire module.")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def test_make_trainable():

cell.branch(0).comp(0.0).set_params("length", 12.0)
cell.branch(1).comp(1.0).set_params("gNa", 0.2)
assert cell.num_trainable_params == 2

cell.branch([0, 1]).make_trainable("radius", 1.0)
assert cell.num_trainable_params == 4
cell.branch([0, 1]).make_trainable("length")
cell.branch([0, 1]).make_trainable("axial_resistivity", [600.0, 700.0])
cell.branch([0, 1]).make_trainable("gNa")
Expand Down Expand Up @@ -69,3 +71,4 @@ def test_make_trainable_network():

cell.get_parameters()
net.GlutamateSynapse.set_params("gS", 0.1)
assert cell.num_trainable_params == 8 # `set_params()` is ignored.