Skip to content

Commit

Permalink
Enable param sharing with differently sized groups (#514)
Browse files Browse the repository at this point in the history
* fix: enable param sharing with differently sized groups. fixes #501

* doc: add more doc

* wip: fix padding, save wip

* enh: better tests

* fix: rebase and fixes

* chore: add changes to changelog

* add comments to clarify parameter sharing

---------

Co-authored-by: michaeldeistler <[email protected]>
  • Loading branch information
jnsbck and michaeldeistler authored Dec 5, 2024
1 parent 79f311c commit 4ea4c16
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ net.arrange_in_layers([3,3])
net.vis()
```

- Allow parameter sharing for groups of different sizes, e.g. due to inhomogeneous numbers of compartments or for synapses with the same (pre-)synaptic parameters but different numbers of post-synaptic partners. (#514, @jnsbck)

# 0.5.0

### API changes
Expand Down
32 changes: 23 additions & 9 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,9 +1058,6 @@ def make_trainable(
ncomps_per_branch = (
self.base.nodes["global_branch_index"].value_counts().to_numpy()
)
assert np.all(
ncomps_per_branch == ncomps_per_branch[0]
), "Parameter sharing is not allowed for modules containing branches with different numbers of compartments."

data = self.nodes if key in self.nodes.columns else None
data = self.edges if key in self.edges.columns else data
Expand All @@ -1075,14 +1072,31 @@ def make_trainable(
grouped_view = data.groupby("controlled_by_param")
# Because of this `x.index.values` we cannot support `make_trainable()` on
# the module level for synapse parameters (but only for `SynapseView`).
inds_of_comps = list(
comp_inds = list(
grouped_view.apply(lambda x: x.index.values, include_groups=False)
)
indices_per_param = jnp.stack(inds_of_comps)

# check if all shapes in comp_inds are the same. If not the case this means
# the groups in controlled_by_param have different sizes, i.e. due to different
# number of comps for two different branches. In this case we pad the smaller
# groups with -1 to make them the same size.
lens = np.array([inds.shape[0] for inds in comp_inds])
max_len = np.max(lens)
pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1)
if not np.all(lens == max_len):
comp_inds = [
pad(inds) if inds.shape[0] < max_len else inds for inds in comp_inds
]

# Sorted inds are only used to infer the correct starting values.
param_vals = jnp.asarray(
[data.loc[inds, key].to_numpy() for inds in inds_of_comps]
)
indices_per_param = jnp.stack(comp_inds)

# Assign dummy param (ignored by nanmean later). This adds a new row to the
# `data` (which is, e.g., self.nodes). That new row has index `-1`, which does
# not clash with any other node index (they are in
# `[0, ..., num_total_comps-1]`).
data.loc[-1, key] = np.nan
param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds])

# Set the value which the trainable parameter should take.
num_created_parameters = len(indices_per_param)
Expand All @@ -1099,7 +1113,7 @@ def make_trainable(
f"init_val must a float, list, or None, but it is a {type(init_val).__name__}."
)
else:
new_params = jnp.mean(param_vals, axis=1)
new_params = jnp.nanmean(param_vals, axis=1)
self.base.trainable_params.append({key: new_params})
self.base.indices_set_by_trainables.append(indices_per_param)
self.base.num_trainable_params += num_created_parameters
Expand Down
25 changes: 25 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,28 @@ def test_write_trainables(SimpleNet):
# Test whether synapse view raises an error.
with pytest.raises(AssertionError):
net.select(edges=[0, 2, 3]).write_trainables(params)


def test_param_sharing_w_different_group_sizes():
    """Parameter sharing with unequally sized `controlled_by_param` groups.

    Builds one branch whose "radius" is trained via shared groups of sizes
    3, 2, and 1, and a second branch whose radii are assigned directly via
    `set`. Both routes must produce identical compartment parameters.
    """
    # Route 1: `make_trainable` with shared groups of unequal size.
    shared_branch = jx.Branch(nseg=6)
    shared_branch.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2])
    shared_branch.make_trainable("radius")
    # Three distinct group labels -> three trainable parameters.
    assert shared_branch.num_trainable_params == 3

    # Assign one value per shared group and materialize all parameters.
    trainables = shared_branch.get_parameters()
    trainables[0]["radius"] = trainables[0]["radius"].at[:].set([2, 3, 4])
    shared_branch.to_jax()
    shared_pstate = params_to_pstate(
        trainables, shared_branch.indices_set_by_trainables
    )
    shared_params = shared_branch.get_all_parameters(
        shared_pstate, voltage_solver="jaxley.thomas"
    )

    # Route 2: write the equivalent per-compartment values with `set`.
    direct_branch = jx.Branch(nseg=6)
    direct_branch.set("radius", np.array([2, 2, 2, 3, 3, 4]))
    direct_trainables = direct_branch.get_parameters()
    direct_branch.to_jax()
    direct_pstate = params_to_pstate(
        direct_trainables, direct_branch.indices_set_by_trainables
    )
    direct_params = direct_branch.get_all_parameters(
        direct_pstate, voltage_solver="jaxley.thomas"
    )

    # Both routes must agree element-wise (NaNs, if any, compare equal).
    assert np.array_equal(
        shared_params["radius"], direct_params["radius"], equal_nan=True
    )

0 comments on commit 4ea4c16

Please sign in to comment.