Skip to content

Commit

Permalink
Fix bug in summed coupling conductance of network
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 8, 2023
1 parent bf6e2fc commit 9d72d4d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
23 changes: 9 additions & 14 deletions neurax/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,18 @@ def init_conds(self, params):
lengths[child_inds, -1],
lengths[par_inds, 0],
)
branch_conds_fwd = jnp.zeros((nbranches))
branch_conds_bwd = jnp.zeros((nbranches))
branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0])
branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1])

summed_coupling_conds = self.update_summed_coupling_conds(
summed_coupling_conds,
child_inds,
conds[0],
conds[1],
branch_conds_fwd,
branch_conds_bwd,
parents,
)
print("summed", summed_coupling_conds)

branch_conds_fwd = jnp.zeros((nbranches))
branch_conds_bwd = jnp.zeros((nbranches))
branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0])
branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1])

cond_params = {
"coupling_conds_fwd": coupling_conds_fwd,
Expand Down Expand Up @@ -192,11 +191,7 @@ def update_summed_coupling_conds(
parents: shape [num_branches]
"""

# print("conds_bwd", conds_bwd)
# print("child_inds", child_inds)
# print("pre", summed_conds)
summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds - 1])
# print("mid", summed_conds)
summed_conds = summed_conds.at[child_inds, -1].add(conds_bwd[child_inds])

dnums = ScatterDimensionNumbers(
update_window_dims=(),
Expand All @@ -206,7 +201,7 @@ def update_summed_coupling_conds(
summed_conds = scatter_add(
summed_conds,
jnp.stack([parents[child_inds], jnp.zeros_like(parents[child_inds])]).T,
conds_fwd[child_inds - 1],
conds_fwd[child_inds],
dnums,
)
return summed_conds
Expand Down
6 changes: 3 additions & 3 deletions neurax/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ def init_conds(self, params):
branch_conds_bwd = jnp.zeros((nbranches))
branch_conds_fwd = branch_conds_fwd.at[child_inds].set(conds[0])
branch_conds_bwd = branch_conds_bwd.at[child_inds].set(conds[1])

summed_coupling_conds = Cell.update_summed_coupling_conds(
summed_coupling_conds,
child_inds,
conds[0],
conds[1],
branch_conds_fwd,
branch_conds_bwd,
parents,
)

Expand Down

0 comments on commit 9d72d4d

Please sign in to comment.