-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Connectivity_matrix_connect update #489
base: main
Are you sure you want to change the base?
Changes from 9 commits
2d6a6f0
8a7b370
dbb8ef5
cf498f6
588198a
322e60c
b80f462
4971c32
5471c85
2a85d5a
23b31fe
c504689
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,33 +44,52 @@ def fully_connect( | |
pre_cell_view: "View", | ||
post_cell_view: "View", | ||
synapse_type: "Synapse", | ||
random_post_comp: bool = False, | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
"""Appends multiple connections which build a fully connected layer. | ||
|
||
Connections are from branch 0 location 0 to a randomly chosen branch and loc. | ||
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 | ||
location 0 of the post-synaptic cell unless random_post_comp=True. | ||
|
||
Args: | ||
pre_cell_view: View of the presynaptic cell. | ||
post_cell_view: View of the postsynaptic cell. | ||
synapse_type: The synapse to append. | ||
random_post_comp: If True, randomly samples the postsynaptic compartments. | ||
""" | ||
# Get pre- and postsynaptic cell indices. | ||
num_pre = len(pre_cell_view._cells_in_view) | ||
num_post = len(post_cell_view._cells_in_view) | ||
|
||
# Infer indices of (random) postsynaptic compartments. | ||
global_post_indices = ( | ||
post_cell_view.nodes.groupby("global_cell_index") | ||
.sample(num_pre, replace=True) | ||
.index.to_numpy() | ||
) | ||
global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel() | ||
post_rows = post_cell_view.nodes.loc[global_post_indices] | ||
|
||
# Pre-synapse is at the zero-eth branch and zero-eth compartment. | ||
pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy() | ||
# Repeat rows `num_post` times. See SO 50788508. | ||
pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True) | ||
# Get the indices of the connections, like it's a fully connected connectivity matrix | ||
from_idx = np.repeat(range(0, num_pre), num_post) | ||
to_idx = np.tile( | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
range(0, num_post), num_pre | ||
) # used only if random_post_comp is False | ||
|
||
# Pre-synapse at the zero-eth branch and zero-eth compartment | ||
global_pre_comp_indices = ( | ||
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) # setting scope ensure that this works indep of current scope | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make sure to reset this after, in case the user is in global scope |
||
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes | ||
|
||
if random_post_comp: | ||
# Randomly sample the post-synaptic compartments | ||
global_post_comp_indices = ( | ||
post_cell_view.nodes.groupby("global_cell_index") | ||
.sample(num_pre, replace=True) | ||
.index.to_numpy() | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
global_post_comp_indices = global_post_comp_indices.reshape( | ||
(-1, num_pre), order="F" | ||
).ravel() | ||
else: | ||
# Post-synapse also at the zero-eth branch and zero-eth compartment | ||
global_post_comp_indices = ( | ||
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
) | ||
global_post_comp_indices = global_post_comp_indices[to_idx] | ||
post_rows = post_cell_view.nodes.loc[global_post_comp_indices] | ||
|
||
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) | ||
|
||
|
@@ -80,45 +99,50 @@ def sparse_connect( | |
post_cell_view: "View", | ||
synapse_type: "Synapse", | ||
p: float, | ||
random_post_comp: bool = False, | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
"""Appends multiple connections which build a sparse, randomly connected layer. | ||
|
||
Connections are from branch 0 location 0 to a randomly chosen branch and loc. | ||
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 | ||
location 0 of the post-synaptic cell unless random_post_comp=True. | ||
|
||
Args: | ||
pre_cell_view: View of the presynaptic cell. | ||
post_cell_view: View of the postsynaptic cell. | ||
synapse_type: The synapse to append. | ||
p: Probability of connection. | ||
random_post_comp: If True, randomly samples the postsynaptic compartments. | ||
""" | ||
# Get pre- and postsynaptic cell indices. | ||
pre_cell_inds = pre_cell_view._cells_in_view | ||
post_cell_inds = post_cell_view._cells_in_view | ||
num_pre = len(pre_cell_inds) | ||
num_post = len(post_cell_inds) | ||
|
||
num_connections = np.random.binomial(num_pre * num_post, p) | ||
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections) | ||
post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections) | ||
|
||
# Sort the synapses only for convenience of inspecting `.edges`. | ||
sorting = np.argsort(pre_syn_neurons) | ||
pre_syn_neurons = pre_syn_neurons[sorting] | ||
post_syn_neurons = post_syn_neurons[sorting] | ||
|
||
# Post-synapse is a randomly chosen branch and compartment. | ||
global_post_indices = [ | ||
sample_comp(post_cell_view.scope("global").cell(cell_idx)) | ||
for cell_idx in post_syn_neurons | ||
] | ||
global_post_indices = ( | ||
np.hstack(global_post_indices) if len(global_post_indices) > 1 else [] | ||
) | ||
post_rows = post_cell_view.base.nodes.loc[global_post_indices] | ||
|
||
# Pre-synapse is at the zero-eth branch and zero-eth compartment. | ||
global_pre_indices = pre_cell_view.base._cumsum_nseg_per_cell[pre_syn_neurons] | ||
pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices] | ||
num_pre = len(pre_cell_view._cells_in_view) | ||
num_post = len(post_cell_view._cells_in_view) | ||
|
||
# Generate random cxns without duplicates --> respects p but memory intesive if extremely large n cells | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I implemented the connect API, @michaeldeistler and me had discussed not wanting to do this, since this can be really bad for large, sparse networks. Did we change our mind about doing it this way? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree its cleaner and easier to parse, but not efficient. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch Jonas! Let's avoid building a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was looking into this for quite a while, and all of the ways that are memory efficient are extremely time intensive, and vice versa, so I came up with a way to split the cost a bit... generating blocks of random matrix at a time. I am still thinking of how it could be prettier, but the resource management seems to be good There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the reason we cannot do it as it was before? If really would like to avoid memory scaling with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function before created many duplicate connections between the same cells, and simply filtering these out would remove a lot of connections (inefficient) and make the p value entered somewhat meaningless (more problematic). Also p=1.0 did not create a fully connected graph, far from it I might add. Resource efficient methods for constructing sparse random graphs are highly sought after and typically trade off computation speed and memory, so I tried to write a function that balances the two. The function now isn't O(N^2) per se, the largest matrix it creates is 100x100. I can keep looking for better algorithms in the future, but for now the function at least fixes the bug of duplicate connections and restores the traditional meaning of p. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm also thinking now that just removing sparse_connect() might not be so bad... I think we might have used connectivity matrix connect or something else in the RNN experiments anyway |
||
connectivity_matrix = np.random.binomial(1, p, (num_pre, num_post)) | ||
from_idx, to_idx = np.where(connectivity_matrix) | ||
|
||
# Pre-synapse at the zero-eth branch and zero-eth compartment | ||
global_pre_comp_indices = ( | ||
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
) # setting scope ensure that this works indep of current scope | ||
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes | ||
|
||
if random_post_comp: | ||
# Randomly sample the post-synaptic compartments | ||
global_post_comp_indices = ( | ||
post_cell_view.nodes.groupby("global_cell_index") | ||
.sample(num_pre, replace=True) | ||
.index.to_numpy() | ||
) | ||
global_post_comp_indices = global_post_comp_indices.reshape( | ||
(-1, num_pre), order="F" | ||
).ravel() | ||
else: | ||
# Post-synapse also at the zero-eth branch and zero-eth compartment | ||
global_post_comp_indices = ( | ||
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
) | ||
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes | ||
|
||
if len(pre_rows) > 0: | ||
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) | ||
|
@@ -129,49 +153,54 @@ def connectivity_matrix_connect( | |
post_cell_view: "View", | ||
synapse_type: "Synapse", | ||
connectivity_matrix: np.ndarray[bool], | ||
random_post_comp: bool = False, | ||
kyralianaka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
"""Appends multiple connections which build a custom connected network. | ||
"""Appends multiple connections according to a custom connectivity matrix. | ||
|
||
Connects pre- and postsynaptic cells according to a custom connectivity matrix. | ||
Entries > 0 in the matrix indicate a connection between the corresponding cells. | ||
Connections are from branch 0 location 0 to a randomly chosen branch and loc. | ||
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0 | ||
location 0 of the post-synaptic cell unless random_post_comp=True. | ||
|
||
Args: | ||
pre_cell_view: View of the presynaptic cell. | ||
post_cell_view: View of the postsynaptic cell. | ||
synapse_type: The synapse to append. | ||
connectivity_matrix: A boolean matrix indicating the connections between cells. | ||
random_post_comp: If True, randomly samples the postsynaptic compartments. | ||
""" | ||
# Get pre- and postsynaptic cell indices. | ||
pre_cell_inds = pre_cell_view._cells_in_view | ||
post_cell_inds = post_cell_view._cells_in_view | ||
# setting scope ensure that this works indep of current scope | ||
pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes | ||
pre_nodes["index"] = pre_nodes.index | ||
pre_cell_nodes = pre_nodes.set_index("global_cell_index") | ||
# Get pre- and postsynaptic cell indices | ||
num_pre = len(pre_cell_view._cells_in_view) | ||
num_post = len(post_cell_view._cells_in_view) | ||
|
||
assert connectivity_matrix.shape == ( | ||
len(pre_cell_inds), | ||
len(post_cell_inds), | ||
num_pre, | ||
num_post, | ||
), "Connectivity matrix must have shape (num_pre, num_post)." | ||
assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean." | ||
|
||
# get connection pairs from connectivity matrix | ||
# Get pre to post connection pairs from connectivity matrix | ||
from_idx, to_idx = np.where(connectivity_matrix) | ||
pre_cell_inds = pre_cell_inds[from_idx] | ||
post_cell_inds = post_cell_inds[to_idx] | ||
|
||
# Sample random postsynaptic compartments (global comp indices). | ||
global_post_indices = np.hstack( | ||
[ | ||
sample_comp(post_cell_view.scope("global").cell(cell_idx)) | ||
for cell_idx in post_cell_inds | ||
] | ||
) | ||
post_rows = post_cell_view.nodes.loc[global_post_indices] | ||
|
||
# Pre-synapse is at the zero-eth branch and zero-eth compartment. | ||
global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, "index"].to_numpy() | ||
pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes | ||
|
||
# Pre-synapse at the zero-eth branch and zero-eth compartment | ||
global_pre_comp_indices = ( | ||
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
) # setting scope ensure that this works indep of current scope | ||
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes | ||
|
||
if random_post_comp: | ||
global_post_comp_indices = ( | ||
post_cell_view.nodes.groupby("global_cell_index") | ||
.sample(len(from_idx), replace=True) | ||
.index.to_numpy() | ||
) | ||
global_post_comp_indices = global_post_comp_indices.reshape( | ||
(-1, len(from_idx)), order="F" | ||
).ravel() | ||
else: | ||
# Post-synapse also at the zero-eth branch and zero-eth compartment | ||
global_post_comp_indices = ( | ||
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy() | ||
) | ||
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes | ||
|
||
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,13 +33,13 @@ def test_network_grad(): | |
_ = np.random.seed(0) | ||
pre = net.cell([0, 1, 2]) | ||
post = net.cell([3, 4, 5]) | ||
fully_connect(pre, post, IonotropicSynapse()) | ||
fully_connect(pre, post, TestSynapse()) | ||
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) | ||
fully_connect(pre, post, TestSynapse(), random_post_comp=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a test for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All of the tests in test_connection.py are with random_post_comp=False, but the tests that used connection.py everywhere else (test_grad.py and test_basic_modules.py) use random_post_comp=True with fully connect (so that the simulation results are the same as before). I could add tests for random_post_comp=True to test_connection.py -- would this then be enough coverage? |
||
|
||
pre = net.cell([3, 4, 5]) | ||
post = net.cell(6) | ||
fully_connect(pre, post, IonotropicSynapse()) | ||
fully_connect(pre, post, TestSynapse()) | ||
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True) | ||
fully_connect(pre, post, TestSynapse(), random_post_comp=True) | ||
|
||
area = 2 * pi * 10.0 * 1.0 | ||
point_process_to_dist_factor = 100_000.0 / area | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we rename
fully_connect
tofully_connect_cells
? Might be less confusing, since otherwise connections are made comp 2 comp