From bf6e2fc5f0d5ea2ea052f1d05390f0c7a98aa8dc Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Wed, 8 Nov 2023 09:57:00 +0100 Subject: [PATCH 1/2] bugfixes for network --- neurax/modules/cell.py | 5 +++++ neurax/modules/network.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/neurax/modules/cell.py b/neurax/modules/cell.py index 4c34ad65..942abc58 100644 --- a/neurax/modules/cell.py +++ b/neurax/modules/cell.py @@ -138,6 +138,7 @@ def init_conds(self, params): conds[1], parents, ) + print("summed", summed_coupling_conds) branch_conds_fwd = jnp.zeros((nbranches)) branch_conds_bwd = jnp.zeros((nbranches)) @@ -191,7 +192,11 @@ def update_summed_coupling_conds( parents: shape [num_branches] """ + # print("conds_bwd", conds_bwd) + # print("child_inds", child_inds) + # print("pre", summed_conds) summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds - 1]) + # print("mid", summed_conds) dnums = ScatterDimensionNumbers( update_window_dims=(), diff --git a/neurax/modules/network.py b/neurax/modules/network.py index a9afbc8b..629fec6d 100644 --- a/neurax/modules/network.py +++ b/neurax/modules/network.py @@ -166,13 +166,18 @@ def init_conds(self, params): child_inds = self.branch_edges["child_branch_index"].to_numpy() conds = vmap(Cell.init_cell_conds, in_axes=(0, 0, 0, 0, 0, 0))( - axial_resistivity[par_inds, 0], axial_resistivity[child_inds, -1], - radiuses[par_inds, 0], + axial_resistivity[par_inds, 0], radiuses[child_inds, -1], - lengths[par_inds, 0], + radiuses[par_inds, 0], lengths[child_inds, -1], + lengths[par_inds, 0], ) + branch_conds_fwd = jnp.zeros((nbranches)) + branch_conds_bwd = jnp.zeros((nbranches)) + branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) + branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) + summed_coupling_conds = Cell.update_summed_coupling_conds( summed_coupling_conds, child_inds, @@ -181,11 +186,6 @@ def init_conds(self, params): parents, ) - branch_conds_fwd = jnp.zeros((nbranches)) - branch_conds_bwd = jnp.zeros((nbranches)) - branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) - branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) - cond_params = { "coupling_conds_fwd": coupling_conds_fwd, "coupling_conds_bwd": coupling_conds_bwd, From 9d72d4de70febf008ff9a35d8acfefa6ea8df3a9 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Wed, 8 Nov 2023 10:13:11 +0100 Subject: [PATCH 2/2] Fix bug in summed coupling conductance of network --- neurax/modules/cell.py | 23 +++++++++-------------- neurax/modules/network.py | 6 +++--- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/neurax/modules/cell.py b/neurax/modules/cell.py index 942abc58..c8809cd8 100644 --- a/neurax/modules/cell.py +++ b/neurax/modules/cell.py @@ -131,19 +131,18 @@ def init_conds(self, params): lengths[child_inds, -1], lengths[par_inds, 0], ) + branch_conds_fwd = jnp.zeros((nbranches)) + branch_conds_bwd = jnp.zeros((nbranches)) + branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) + branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) + summed_coupling_conds = self.update_summed_coupling_conds( summed_coupling_conds, child_inds, - conds[0], - conds[1], + branch_conds_fwd, + branch_conds_bwd, parents, ) - print("summed", summed_coupling_conds) - - branch_conds_fwd = jnp.zeros((nbranches)) - branch_conds_bwd = jnp.zeros((nbranches)) - branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) - branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) cond_params = { "coupling_conds_fwd": coupling_conds_fwd, @@ -192,11 +191,7 @@ def update_summed_coupling_conds( parents: shape [num_branches] """ - # print("conds_bwd", conds_bwd) - # print("child_inds", child_inds) - # print("pre", summed_conds) - summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds - 1]) - # print("mid", summed_conds) + summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds]) dnums = ScatterDimensionNumbers( update_window_dims=(), @@ -206,7 +201,7 @@ def update_summed_coupling_conds( summed_conds = scatter_add( summed_conds, jnp.stack([parents[child_inds], jnp.zeros_like(parents[child_inds])]).T, - conds_fwd[child_inds - 1], + conds_fwd[child_inds], dnums, ) return summed_conds diff --git a/neurax/modules/network.py b/neurax/modules/network.py index 629fec6d..96d5d0e4 100644 --- a/neurax/modules/network.py +++ b/neurax/modules/network.py @@ -177,12 +177,12 @@ def init_conds(self, params): branch_conds_bwd = jnp.zeros((nbranches)) branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) - + summed_coupling_conds = Cell.update_summed_coupling_conds( summed_coupling_conds, child_inds, - conds[0], - conds[1], + branch_conds_fwd, + branch_conds_bwd, parents, )