Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable param sharing with differently sized groups #514

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ net.arrange_in_layers([3,3])
net.vis()
```

- Allow parameter sharing for groups of different sizes, e.g. due to inhomogeneous numbers of compartments or for synapses with the same (pre-)synaptic parameters but different numbers of post-synaptic partners. (#514, @jnsbck)

# 0.5.0

### API changes
Expand Down
32 changes: 23 additions & 9 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,9 +1058,6 @@ def make_trainable(
ncomps_per_branch = (
self.base.nodes["global_branch_index"].value_counts().to_numpy()
)
assert np.all(
ncomps_per_branch == ncomps_per_branch[0]
), "Parameter sharing is not allowed for modules containing branches with different numbers of compartments."

data = self.nodes if key in self.nodes.columns else None
data = self.edges if key in self.edges.columns else data
Expand All @@ -1075,14 +1072,31 @@ def make_trainable(
grouped_view = data.groupby("controlled_by_param")
# Because of this `x.index.values` we cannot support `make_trainable()` on
# the module level for synapse parameters (but only for `SynapseView`).
inds_of_comps = list(
comp_inds = list(
grouped_view.apply(lambda x: x.index.values, include_groups=False)
)
indices_per_param = jnp.stack(inds_of_comps)

# check if all shapes in comp_inds are the same. If not the case this means
# the groups in controlled_by_param have different sizes, i.e. due to different
# number of comps for two different branches. In this case we pad the smaller
# groups with -1 to make them the same size.
lens = np.array([inds.shape[0] for inds in comp_inds])
max_len = np.max(lens)
pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1)
if not np.all(lens == max_len):
comp_inds = [
pad(inds) if inds.shape[0] < max_len else inds for inds in comp_inds
]

# Sorted inds are only used to infer the correct starting values.
param_vals = jnp.asarray(
[data.loc[inds, key].to_numpy() for inds in inds_of_comps]
)
indices_per_param = jnp.stack(comp_inds)

# Assign dummy param (ignored by nanmean later). This adds a new row to the
# `data` (which is, e.g., self.nodes). That new row has index `-1`, which does
# not clash with any other node index (they are in
# `[0, ..., num_total_comps-1]`).
data.loc[-1, key] = np.nan
param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds])

# Set the value which the trainable parameter should take.
num_created_parameters = len(indices_per_param)
Expand All @@ -1099,7 +1113,7 @@ def make_trainable(
f"init_val must a float, list, or None, but it is a {type(init_val).__name__}."
)
else:
new_params = jnp.mean(param_vals, axis=1)
new_params = jnp.nanmean(param_vals, axis=1)
self.base.trainable_params.append({key: new_params})
self.base.indices_set_by_trainables.append(indices_per_param)
self.base.num_trainable_params += num_created_parameters
Expand Down
25 changes: 25 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,28 @@ def test_write_trainables(SimpleNet):
# Test whether synapse view raises an error.
with pytest.raises(AssertionError):
net.select(edges=[0, 2, 3]).write_trainables(params)


def test_param_sharing_w_different_group_sizes():
    """Param sharing with unevenly sized groups must match an equivalent `set`.

    Groups of sizes 3, 2, and 1 are defined via `controlled_by_param`; the
    parameters obtained through `make_trainable` must equal those written
    directly with `set`.
    """
    # Trainable route: three shared-parameter groups of different sizes.
    trainable_branch = jx.Branch(nseg=6)
    trainable_branch.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2])
    trainable_branch.make_trainable("radius")
    assert trainable_branch.num_trainable_params == 3

    # Assign one radius per group and materialize all parameters.
    trainable_params = trainable_branch.get_parameters()
    trainable_params[0]["radius"] = (
        trainable_params[0]["radius"].at[:].set([2, 3, 4])
    )
    trainable_branch.to_jax()
    pstate = params_to_pstate(
        trainable_params, trainable_branch.indices_set_by_trainables
    )
    via_trainable = trainable_branch.get_all_parameters(
        pstate, voltage_solver="jaxley.thomas"
    )

    # Reference route: write the same per-compartment radii with `set`.
    reference_branch = jx.Branch(nseg=6)
    reference_branch.set("radius", np.array([2, 2, 2, 3, 3, 4]))
    reference_params = reference_branch.get_parameters()
    reference_branch.to_jax()
    pstate = params_to_pstate(
        reference_params, reference_branch.indices_set_by_trainables
    )
    via_set = reference_branch.get_all_parameters(
        pstate, voltage_solver="jaxley.thomas"
    )

    assert np.array_equal(via_trainable["radius"], via_set["radius"], equal_nan=True)