diff --git a/jaxley/connection.py b/jaxley/connection.py index 73265138..fe5e90bc 100644 --- a/jaxley/connection.py +++ b/jaxley/connection.py @@ -50,22 +50,22 @@ def fc(self, pre_cell_inds, post_cell_inds): rand_loc = np.random.rand() conns.append(Connection(pre_ind, 0, 0, post_ind, rand_branch, rand_loc)) return conns - + def sparse_random(self, pre_cell_inds, post_cell_inds, p): - """Returns a list of `Connection`s which build a sparse, randomly - connected layer. The presence of each connection is determined by a - Bernoulli trial with probability p. + """Returns a list of `Connection`s which build a sparse, randomly + connected layer. - Connections are from branch 0 location 0 to a randomly chosen branch + Connections are from branch 0 location 0 to a randomly chosen branch and loc. - - NOTE: autapses are allowed. """ + num_pre = len(pre_cell_inds) + num_post = len(post_cell_inds) + num_connections = np.random.binomial(num_pre * num_post, p) + pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections) + post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections) + conns = [] - adj_mat = np.random.binomial(1, p, size=(len(pre_cell_inds), - len(post_cell_inds))) - cxn_inds = np.where(adj_mat) - for pre_ind, post_ind in zip(cxn_inds[0], cxn_inds[1]): + for pre_ind, post_ind in zip(pre_syn_neurons, post_syn_neurons): num_branches_post = self.nbranches_per_submodule[post_ind] rand_branch = np.random.randint(0, num_branches_post) rand_loc = np.random.rand()