diff --git a/neurax/modules/cell.py b/neurax/modules/cell.py index 942abc58b..33287a024 100644 --- a/neurax/modules/cell.py +++ b/neurax/modules/cell.py @@ -131,19 +131,18 @@ def init_conds(self, params): lengths[child_inds, -1], lengths[par_inds, 0], ) + branch_conds_fwd = jnp.zeros((nbranches)) + branch_conds_bwd = jnp.zeros((nbranches)) + branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) + branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) + summed_coupling_conds = self.update_summed_coupling_conds( summed_coupling_conds, child_inds, - conds[0], - conds[1], + branch_conds_fwd, + branch_conds_bwd, parents, ) - print("summed", summed_coupling_conds) - - branch_conds_fwd = jnp.zeros((nbranches)) - branch_conds_bwd = jnp.zeros((nbranches)) - branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) - branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) cond_params = { "coupling_conds_fwd": coupling_conds_fwd, @@ -191,12 +190,8 @@ def update_summed_coupling_conds( conds_bwd: shape [num_branches - 1] parents: shape [num_branches] """ - - # print("conds_bwd", conds_bwd) - # print("child_inds", child_inds) - # print("pre", summed_conds) - summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds - 1]) - # print("mid", summed_conds) + + summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds]) dnums = ScatterDimensionNumbers( update_window_dims=(), @@ -206,7 +201,7 @@ def update_summed_coupling_conds( summed_conds = scatter_add( summed_conds, jnp.stack([parents[child_inds], jnp.zeros_like(parents[child_inds])]).T, - conds_fwd[child_inds - 1], + conds_fwd[child_inds], dnums, ) return summed_conds diff --git a/neurax/modules/network.py b/neurax/modules/network.py index 629fec6d1..96d5d0e4d 100644 --- a/neurax/modules/network.py +++ b/neurax/modules/network.py @@ -177,12 +177,12 @@ def init_conds(self, params): branch_conds_bwd = jnp.zeros((nbranches)) branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0]) branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1]) - + summed_coupling_conds = Cell.update_summed_coupling_conds( summed_coupling_conds, child_inds, - conds[0], - conds[1], + branch_conds_fwd, + branch_conds_bwd, parents, )