Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connectivity_matrix_connect update #489

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
169 changes: 99 additions & 70 deletions jaxley/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,52 @@ def fully_connect(
pre_cell_view: "View",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename fully_connect to fully_connect_cells? Might be less confusing, since otherwise connections are made comp 2 comp

post_cell_view: "View",
synapse_type: "Synapse",
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a fully connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Infer indices of (random) postsynaptic compartments.
global_post_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
)
global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel()
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
# Repeat rows `num_post` times. See SO 50788508.
pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)
# Get the indices of the connections, like it's a fully connected connectivity matrix
from_idx = np.repeat(range(0, num_pre), num_post)
to_idx = np.tile(
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
range(0, num_post), num_pre
) # used only if random_post_comp is False

# Pre-synapse at the zero-eth branch and zero-eth compartment
global_pre_comp_indices = (
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
) # setting scope ensure that this works indep of current scope
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure to reset this after, in case the user is in global scope

pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes

if random_post_comp:
# Randomly sample the post-synaptic compartments
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
)
global_post_comp_indices = global_post_comp_indices.reshape(
(-1, num_pre), order="F"
).ravel()
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
global_post_comp_indices = (
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
)
global_post_comp_indices = global_post_comp_indices[to_idx]
post_rows = post_cell_view.nodes.loc[global_post_comp_indices]

pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)

Expand All @@ -80,45 +99,50 @@ def sparse_connect(
post_cell_view: "View",
synapse_type: "Synapse",
p: float,
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
p: Probability of connection.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)

num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)

# Sort the synapses only for convenience of inspecting `.edges`.
sorting = np.argsort(pre_syn_neurons)
pre_syn_neurons = pre_syn_neurons[sorting]
post_syn_neurons = post_syn_neurons[sorting]

# Post-synapse is a randomly chosen branch and compartment.
global_post_indices = [
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_syn_neurons
]
global_post_indices = (
np.hstack(global_post_indices) if len(global_post_indices) > 1 else []
)
post_rows = post_cell_view.base.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.base._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Generate random cxns without duplicates --> respects p but memory intesive if extremely large n cells
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I implemented the connect API, @michaeldeistler and me had discussed not wanting to do this, since this can be really bad for large, sparse networks. Did we change our mind about doing it this way?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree its cleaner and easier to parse, but not efficient.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch Jonas! Let's avoid building a n x n matrix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was looking into this for quite a while, and all of the ways that are memory efficient are extremely time intensive, and vice versa, so I came up with a way to split the cost a bit... generating blocks of random matrix at a time. I am still thinking of how it could be prettier, but the resource management seems to be good

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason we cannot do it as it was before? If really would like to avoid memory scaling with O(N^2).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function before created many duplicate connections between the same cells, and simply filtering these out would remove a lot of connections (inefficient) and make the p value entered somewhat meaningless (more problematic). Also p=1.0 did not create a fully connected graph, far from it I might add. Resource efficient methods for constructing sparse random graphs are highly sought after and typically trade off computation speed and memory, so I tried to write a function that balances the two. The function now isn't O(N^2) per se, the largest matrix it creates is 100x100. I can keep looking for better algorithms in the future, but for now the function at least fixes the bug of duplicate connections and restores the traditional meaning of p.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also thinking now that just removing sparse_connect() might not be so bad... I think we might have used connectivity matrix connect or something else in the RNN experiments anyway

connectivity_matrix = np.random.binomial(1, p, (num_pre, num_post))
from_idx, to_idx = np.where(connectivity_matrix)

# Pre-synapse at the zero-eth branch and zero-eth compartment
global_pre_comp_indices = (
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
) # setting scope ensure that this works indep of current scope
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes

if random_post_comp:
# Randomly sample the post-synaptic compartments
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
)
global_post_comp_indices = global_post_comp_indices.reshape(
(-1, num_pre), order="F"
).ravel()
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
global_post_comp_indices = (
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
)
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes

if len(pre_rows) > 0:
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
Expand All @@ -129,49 +153,54 @@ def connectivity_matrix_connect(
post_cell_view: "View",
synapse_type: "Synapse",
connectivity_matrix: np.ndarray[bool],
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a custom connected network.
"""Appends multiple connections according to a custom connectivity matrix.

Connects pre- and postsynaptic cells according to a custom connectivity matrix.
Entries > 0 in the matrix indicate a connection between the corresponding cells.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
connectivity_matrix: A boolean matrix indicating the connections between cells.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
# setting scope ensure that this works indep of current scope
pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes
pre_nodes["index"] = pre_nodes.index
pre_cell_nodes = pre_nodes.set_index("global_cell_index")
# Get pre- and postsynaptic cell indices
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

assert connectivity_matrix.shape == (
len(pre_cell_inds),
len(post_cell_inds),
num_pre,
num_post,
), "Connectivity matrix must have shape (num_pre, num_post)."
assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

# get connection pairs from connectivity matrix
# Get pre to post connection pairs from connectivity matrix
from_idx, to_idx = np.where(connectivity_matrix)
pre_cell_inds = pre_cell_inds[from_idx]
post_cell_inds = post_cell_inds[to_idx]

# Sample random postsynaptic compartments (global comp indices).
global_post_indices = np.hstack(
[
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_cell_inds
]
)
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, "index"].to_numpy()
pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes

# Pre-synapse at the zero-eth branch and zero-eth compartment
global_pre_comp_indices = (
pre_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
) # setting scope ensure that this works indep of current scope
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes

if random_post_comp:
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(len(from_idx), replace=True)
.index.to_numpy()
)
global_post_comp_indices = global_post_comp_indices.reshape(
(-1, len(from_idx)), order="F"
).ravel()
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
global_post_comp_indices = (
post_cell_view.scope("local").branch(0).comp(0).nodes.index.to_numpy()
)
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes

pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
8 changes: 4 additions & 4 deletions tests/jaxley_identical/test_basic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,13 @@ def test_complex_net(voltage_solver: str):
_ = np.random.seed(0)
pre = net.cell([0, 1, 2])
post = net.cell([3, 4, 5])
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

pre = net.cell([3, 4, 5])
post = net.cell(6)
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

area = 2 * pi * 10.0 * 1.0
point_process_to_dist_factor = 100_000.0 / area
Expand Down
8 changes: 4 additions & 4 deletions tests/jaxley_identical/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def test_network_grad():
_ = np.random.seed(0)
pre = net.cell([0, 1, 2])
post = net.cell([3, 4, 5])
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a test for random_post_comp=False? At least for fully_connect, but also for the others? At least to check that they do not break?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the tests in test_connection.py are with random_post_comp=False, but the tests that used connection.py everywhere else (test_grad.py and test_basic_modules.py) use random_post_comp=True with fully connect (so that the simulation results are the same as before). I could add tests for random_post_comp=True to test_connection.py -- would this then be enough coverage?


pre = net.cell([3, 4, 5])
post = net.cell(6)
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

area = 2 * pi * 10.0 * 1.0
point_process_to_dist_factor = 100_000.0 / area
Expand Down
Loading
Loading