diff --git a/CHANGELOG.md b/CHANGELOG.md index dbdc3d4c..065c0577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ net.arrange_in_layers([3,3]) net.vis() ``` +- Allow parameter sharing for groups of different sizes, i.e. due to inhomogeneous numbers of compartments or for synapses with the same (pre-)synaptic parameters but different numbers of post-synaptic partners. (#514, @jnsbck) + # 0.5.0 ### API changes diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 59ccbd5c..2893f983 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1058,9 +1058,6 @@ def make_trainable( ncomps_per_branch = ( self.base.nodes["global_branch_index"].value_counts().to_numpy() ) - assert np.all( - ncomps_per_branch == ncomps_per_branch[0] - ), "Parameter sharing is not allowed for modules containing branches with different numbers of compartments." data = self.nodes if key in self.nodes.columns else None data = self.edges if key in self.edges.columns else data @@ -1075,14 +1072,31 @@ def make_trainable( grouped_view = data.groupby("controlled_by_param") # Because of this `x.index.values` we cannot support `make_trainable()` on # the module level for synapse parameters (but only for `SynapseView`). - inds_of_comps = list( + comp_inds = list( grouped_view.apply(lambda x: x.index.values, include_groups=False) ) - indices_per_param = jnp.stack(inds_of_comps) + + # check if all shapes in comp_inds are the same. If this is not the case, it means + # the groups in controlled_by_param have different sizes, i.e. due to different + # number of comps for two different branches. In this case we pad the smaller + # groups with -1 to make them the same size. 
+ lens = np.array([inds.shape[0] for inds in comp_inds]) + max_len = np.max(lens) + pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1) + if not np.all(lens == max_len): + comp_inds = [ + pad(inds) if inds.shape[0] < max_len else inds for inds in comp_inds + ] + # Sorted inds are only used to infer the correct starting values. - param_vals = jnp.asarray( - [data.loc[inds, key].to_numpy() for inds in inds_of_comps] - ) + indices_per_param = jnp.stack(comp_inds) + + # Assign dummy param (ignored by nanmean later). This adds a new row to the + # `data` (which is, e.g., self.nodes). That new row has index `-1`, which does + # not clash with any other node index (they are in + # `[0, ..., num_total_comps-1]`). + data.loc[-1, key] = np.nan + param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds]) # Set the value which the trainable parameter should take. num_created_parameters = len(indices_per_param) @@ -1099,7 +1113,7 @@ def make_trainable( f"init_val must a float, list, or None, but it is a {type(init_val).__name__}." ) else: - new_params = jnp.mean(param_vals, axis=1) + new_params = jnp.nanmean(param_vals, axis=1) self.base.trainable_params.append({key: new_params}) self.base.indices_set_by_trainables.append(indices_per_param) self.base.num_trainable_params += num_created_parameters diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 50ece696..783461b3 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -531,3 +531,28 @@ def test_write_trainables(SimpleNet): # Test whether synapse view raises an error. 
with pytest.raises(AssertionError): net.select(edges=[0, 2, 3]).write_trainables(params) + + +def test_param_sharing_w_different_group_sizes(): + # test if make_trainable corresponds to set + branch1 = jx.Branch(nseg=6) + branch1.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2]) + branch1.make_trainable("radius") + assert branch1.num_trainable_params == 3 + + # make trainable + params = branch1.get_parameters() + params[0]["radius"] = params[0]["radius"].at[:].set([2, 3, 4]) + branch1.to_jax() + pstate = params_to_pstate(params, branch1.indices_set_by_trainables) + params1 = branch1.get_all_parameters(pstate, voltage_solver="jaxley.thomas") + + # set + branch2 = jx.Branch(nseg=6) + branch2.set("radius", np.array([2, 2, 2, 3, 3, 4])) + params = branch2.get_parameters() + branch2.to_jax() + pstate = params_to_pstate(params, branch2.indices_set_by_trainables) + params2 = branch2.get_all_parameters(pstate, voltage_solver="jaxley.thomas") + + assert np.array_equal(params1["radius"], params2["radius"], equal_nan=True)