Skip to content

Commit

Permalink
Fully connected and sparely connected
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 11, 2023
1 parent 4a64cf3 commit 6a74de4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
6 changes: 2 additions & 4 deletions jaxley/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ def fc(self, pre_cell_inds, post_cell_inds):
return conns

def sparse_random(self, pre_cell_inds, post_cell_inds, p):
"""Returns a list of `Connection`s which build a sparse, randomly
connected layer.
"""Returns a list of `Connection`s forming a sparse, randomly connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch
and loc.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)
Expand Down
41 changes: 41 additions & 0 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,47 @@ def __call__(self, index: float):
def __getattr__(self, key):
assert key == "branch"
return BranchView(self.pointer, self.view)

def fully_connect(self, post_cell_view, synapse_type):
"""Returns a list of `Connection`s which build a fully connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
pre_cell_inds = np.unique(self.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())

for pre_ind in pre_cell_inds:
for post_ind in post_cell_inds:
num_branches_post = self.pointer.nbranches_per_cell[post_ind]
rand_branch = np.random.randint(0, num_branches_post)
rand_loc = np.random.rand()

pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc)
post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc)
pre.connect(post, synapse_type)

def sparse_connect(self, post_cell_view, p, synapse_type):
"""Returns a list of `Connection`s forming a sparse, randomly connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
pre_cell_inds = np.unique(self.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())

num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)
num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)

for pre_ind, post_ind in zip(pre_syn_neurons, post_syn_neurons):
num_branches_post = self.pointer.nbranches_per_cell[post_ind]
rand_branch = np.random.randint(0, num_branches_post)
rand_loc = np.random.rand()

pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc)
post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc)
pre.connect(post, synapse_type)


def read_swc(
Expand Down

0 comments on commit 6a74de4

Please sign in to comment.