From 6a74de4ffb5ca593dc2008a0f6b41e3cf3fdb922 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Mon, 11 Dec 2023 18:04:57 +0100 Subject: [PATCH] Fully connected and sparely connected --- jaxley/connection.py | 6 ++---- jaxley/modules/cell.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/jaxley/connection.py b/jaxley/connection.py index fe5e90bc..b5312d4c 100644 --- a/jaxley/connection.py +++ b/jaxley/connection.py @@ -52,11 +52,9 @@ def fc(self, pre_cell_inds, post_cell_inds): return conns def sparse_random(self, pre_cell_inds, post_cell_inds, p): - """Returns a list of `Connection`s which build a sparse, randomly - connected layer. + """Returns a list of `Connection`s forming a sparse, randomly connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch - and loc. + Connections are from branch 0 location 0 to a randomly chosen branch and loc. """ num_pre = len(pre_cell_inds) num_post = len(post_cell_inds) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 35908a30..a2ef74c4 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -258,6 +258,47 @@ def __call__(self, index: float): def __getattr__(self, key): assert key == "branch" return BranchView(self.pointer, self.view) + + def fully_connect(self, post_cell_view, synapse_type): + """Returns a list of `Connection`s which build a fully connected layer. + + Connections are from branch 0 location 0 to a randomly chosen branch and loc. + """ + pre_cell_inds = np.unique(self.view["cell_index"].to_numpy()) + post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy()) + + for pre_ind in pre_cell_inds: + for post_ind in post_cell_inds: + num_branches_post = self.pointer.nbranches_per_cell[post_ind] + rand_branch = np.random.randint(0, num_branches_post) + rand_loc = np.random.rand() + + pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc) + post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc) + pre.connect(post, synapse_type) + + def sparse_connect(self, post_cell_view, p, synapse_type): + """Returns a list of `Connection`s forming a sparse, randomly connected layer. + + Connections are from branch 0 location 0 to a randomly chosen branch and loc. + """ + pre_cell_inds = np.unique(self.view["cell_index"].to_numpy()) + post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy()) + + num_pre = len(pre_cell_inds) + num_post = len(post_cell_inds) + num_connections = np.random.binomial(num_pre * num_post, p) + pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections) + post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections) + + for pre_ind, post_ind in zip(pre_syn_neurons, post_syn_neurons): + num_branches_post = self.pointer.nbranches_per_cell[post_ind] + rand_branch = np.random.randint(0, num_branches_post) + rand_loc = np.random.rand() + + pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc) + post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc) + pre.connect(post, synapse_type) def read_swc(