From 54e2366fb501f78aebaf901e90891e6b834676c5 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 19 Nov 2024 19:00:04 +0100 Subject: [PATCH 1/7] fix: enable param sharing with differently sized groups. fixes #501 --- jaxley/modules/base.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 59ccbd5c..7d15484a 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1075,14 +1075,24 @@ def make_trainable( grouped_view = data.groupby("controlled_by_param") # Because of this `x.index.values` we cannot support `make_trainable()` on # the module level for synapse parameters (but only for `SynapseView`). - inds_of_comps = list( + comp_inds = list( grouped_view.apply(lambda x: x.index.values, include_groups=False) ) - indices_per_param = jnp.stack(inds_of_comps) + + # check if all shapes in comp_inds are the same. If not the case this means + # the groups in controlled_by_param have different sizes, i.e. due to different + # number of comps for two different branches. + lens = np.array([inds.shape[0] for inds in comp_inds]) + max_len = np.max(lens) + pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1) + if not np.all(lens == max_len): + comp_inds = [pad(inds) for inds in comp_inds if inds.shape[0] < max_len] + + indices_per_param = jnp.stack(comp_inds) + # Sorted inds are only used to infer the correct starting values. - param_vals = jnp.asarray( - [data.loc[inds, key].to_numpy() for inds in inds_of_comps] - ) + data.loc[-1, key] = np.nan # assign dummy index to NaN + param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds]) # Set the value which the trainable parameter should take. num_created_parameters = len(indices_per_param) @@ -1099,7 +1109,7 @@ def make_trainable( f"init_val must a float, list, or None, but it is a {type(init_val).__name__}." 
) else: - new_params = jnp.mean(param_vals, axis=1) + new_params = jnp.nanmean(param_vals, axis=1) self.base.trainable_params.append({key: new_params}) self.base.indices_set_by_trainables.append(indices_per_param) self.base.num_trainable_params += num_created_parameters From 60908a4fc16a8ed68243cdfb11782bb5b0b8dada Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 19 Nov 2024 19:06:16 +0100 Subject: [PATCH 2/7] doc: add more doc --- jaxley/modules/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 7d15484a..c12d5977 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1081,7 +1081,8 @@ def make_trainable( # check if all shapes in comp_inds are the same. If not the case this means # the groups in controlled_by_param have different sizes, i.e. due to different - # number of comps for two different branches. + # number of comps for two different branches. In this case we pad the smaller + # groups with -1 to make them the same size. lens = np.array([inds.shape[0] for inds in comp_inds]) max_len = np.max(lens) pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1) @@ -1091,7 +1092,7 @@ def make_trainable( indices_per_param = jnp.stack(comp_inds) # Sorted inds are only used to infer the correct starting values. - data.loc[-1, key] = np.nan # assign dummy index to NaN + data.loc[-1, key] = np.nan # assign dummy param (ignored by nanmean later) param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds]) # Set the value which the trainable parameter should take. 
From 4aa97040afbe4fa332a3482de0667329af96eab9 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 20 Nov 2024 16:09:00 +0100 Subject: [PATCH 3/7] wip: fix padding, save wip --- jaxley/modules/base.py | 4 +++- tests/test_make_trainable.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c12d5977..0c8662be 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1087,7 +1087,9 @@ def make_trainable( max_len = np.max(lens) pad = lambda x: np.pad(x, (0, max_len - x.shape[0]), constant_values=-1) if not np.all(lens == max_len): - comp_inds = [pad(inds) for inds in comp_inds if inds.shape[0] < max_len] + comp_inds = [ + pad(inds) if inds.shape[0] < max_len else inds for inds in comp_inds + ] indices_per_param = jnp.stack(comp_inds) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 50ece696..a6b0e0ea 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -531,3 +531,15 @@ def test_write_trainables(SimpleNet): # Test whether synapse view raises an error. 
with pytest.raises(AssertionError): net.select(edges=[0, 2, 3]).write_trainables(params) + + +def test_param_sharing_w_different_group_sizes(): + branch = jx.Branch(nseg=6) + branch.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2]) + branch.make_trainable("radius") + assert branch.num_trainable_params == 3 + + params = branch.get_parameters() + branch.to_jax() + pstate = params_to_pstate(params, branch.indices_set_by_trainables) + branch.get_all_parameters(pstate, voltage_solver="jaxley.thomas") From 205090b179205bea523dc9e895c64c4a99dc7a09 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 4 Dec 2024 19:23:54 +0100 Subject: [PATCH 4/7] enh: better tests --- tests/test_make_trainable.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index a6b0e0ea..783461b3 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -534,12 +534,25 @@ def test_write_trainables(SimpleNet): def test_param_sharing_w_different_group_sizes(): - branch = jx.Branch(nseg=6) - branch.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2]) - branch.make_trainable("radius") - assert branch.num_trainable_params == 3 - - params = branch.get_parameters() - branch.to_jax() - pstate = params_to_pstate(params, branch.indices_set_by_trainables) - branch.get_all_parameters(pstate, voltage_solver="jaxley.thomas") + # test if make_trainable corresponds to set + branch1 = jx.Branch(nseg=6) + branch1.nodes["controlled_by_param"] = np.array([0, 0, 0, 1, 1, 2]) + branch1.make_trainable("radius") + assert branch1.num_trainable_params == 3 + + # make trainable + params = branch1.get_parameters() + params[0]["radius"] = params[0]["radius"].at[:].set([2, 3, 4]) + branch1.to_jax() + pstate = params_to_pstate(params, branch1.indices_set_by_trainables) + params1 = branch1.get_all_parameters(pstate, voltage_solver="jaxley.thomas") + + # set + branch2 = 
jx.Branch(nseg=6) + branch2.set("radius", np.array([2, 2, 2, 3, 3, 4])) + params = branch2.get_parameters() + branch2.to_jax() + pstate = params_to_pstate(params, branch2.indices_set_by_trainables) + params2 = branch2.get_all_parameters(pstate, voltage_solver="jaxley.thomas") + + assert np.array_equal(params1["radius"], params2["radius"], equal_nan=True) From ab47228b0258d43fbfae872dc6af2090c05f40ff Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 4 Dec 2024 19:29:56 +0100 Subject: [PATCH 5/7] fix: rebase and fixes --- jaxley/modules/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 0c8662be..2186923d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1058,9 +1058,6 @@ def make_trainable( ncomps_per_branch = ( self.base.nodes["global_branch_index"].value_counts().to_numpy() ) - assert np.all( - ncomps_per_branch == ncomps_per_branch[0] - ), "Parameter sharing is not allowed for modules containing branches with different numbers of compartments." data = self.nodes if key in self.nodes.columns else None data = self.edges if key in self.edges.columns else data From 838cf3843bb8f85d5663fd57efb87aa574c207b8 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 4 Dec 2024 19:36:54 +0100 Subject: [PATCH 6/7] chore: add changes to changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dbdc3d4c..065c0577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ net.arrange_in_layers([3,3]) net.vis() ``` +- Allow parameter sharing for groups of different sizes, i.e. due to inhomogeneous numbers of compartments or for synapses with the same (pre-)synaptic parameters but different numbers of post-synaptic partners.
(#514, @jnsbck) + # 0.5.0 ### API changes From c095a90c840ae6ac1510c0bcfc574305db3a2dd8 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 5 Dec 2024 15:05:56 +0100 Subject: [PATCH 7/7] add comments to clarify parameter sharing --- jaxley/modules/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2186923d..2893f983 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1088,10 +1088,14 @@ def make_trainable( pad(inds) if inds.shape[0] < max_len else inds for inds in comp_inds ] + # Sorted inds are only used to infer the correct starting values. indices_per_param = jnp.stack(comp_inds) - # Sorted inds are only used to infer the correct starting values. - data.loc[-1, key] = np.nan # assign dummy param (ignored by nanmean later) + # Assign dummy param (ignored by nanmean later). This adds a new row to the + # `data` (which is, e.g., self.nodes). That new row has index `-1`, which does + # not clash with any other node index (they are in + # `[0, ..., num_total_comps-1]`). + data.loc[-1, key] = np.nan param_vals = jnp.asarray([data.loc[inds, key].to_numpy() for inds in comp_inds]) # Set the value which the trainable parameter should take.