Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.connect() and fully_connect() methods #197

Merged
merged 12 commits into from
Dec 12, 2023
6 changes: 2 additions & 4 deletions jaxley/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ def fc(self, pre_cell_inds, post_cell_inds):
return conns

def sparse_random(self, pre_cell_inds, post_cell_inds, p):
"""Returns a list of `Connection`s which build a sparse, randomly
connected layer.
"""Returns a list of `Connection`s forming a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch
and loc.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)
Expand Down
3 changes: 2 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self):

self.syn_params: Dict[str, jnp.ndarray] = {}
self.syn_states: Dict[str, jnp.ndarray] = {}
self.syn_classes: List = []

# Channel indices, parameters, and states.
self.channel_nodes: Dict[str, pd.DataFrame] = {}
Expand Down Expand Up @@ -535,7 +536,7 @@ def step(
# Step of the synapse.
new_syn_states, syn_voltage_terms, syn_constant_terms = self._step_synapse(
u,
self.conns,
self.syn_classes,
params,
delta_t,
self.syn_edges,
Expand Down
41 changes: 41 additions & 0 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,47 @@ def __getattr__(self, key):
assert key == "branch"
return BranchView(self.pointer, self.view)

def fully_connect(self, post_cell_view, synapse_type):
michaeldeistler marked this conversation as resolved.
Show resolved Hide resolved
"""Returns a list of `Connection`s which build a fully connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
pre_cell_inds = np.unique(self.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())

for pre_ind in pre_cell_inds:
for post_ind in post_cell_inds:
num_branches_post = self.pointer.nbranches_per_cell[post_ind]
rand_branch = np.random.randint(0, num_branches_post)
rand_loc = np.random.rand()

pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc)
post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc)
pre.connect(post, synapse_type)

def sparse_connect(self, post_cell_view, p, synapse_type):
"""Returns a list of `Connection`s forming a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
"""
pre_cell_inds = np.unique(self.view["cell_index"].to_numpy())
post_cell_inds = np.unique(post_cell_view.view["cell_index"].to_numpy())

num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)
num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)

for pre_ind, post_ind in zip(pre_syn_neurons, post_syn_neurons):
num_branches_post = self.pointer.nbranches_per_cell[post_ind]
rand_branch = np.random.randint(0, num_branches_post)
rand_loc = np.random.rand()

pre = self.pointer.cell(pre_ind).branch(rand_branch).comp(rand_loc)
post = self.pointer.cell(post_ind).branch(rand_branch).comp(rand_loc)
pre.connect(post, synapse_type)


def read_swc(
fname: str,
Expand Down
113 changes: 107 additions & 6 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import jax.numpy as jnp
import pandas as pd

from jaxley.channels import Channel
from jaxley.modules.base import Module, View
from jaxley.utils.cell_utils import index_of_loc
from jaxley.utils.cell_utils import index_of_loc, loc_of_index


class Compartment(Module):
Expand All @@ -29,10 +28,6 @@ def __init__(self):
self.nodes = pd.DataFrame(
dict(comp_index=[0], branch_index=[0], cell_index=[0])
)
# Synapse indexing.
self.syn_edges = pd.DataFrame(
dict(global_pre_comp_index=[], global_post_comp_index=[], type="")
)
self.branch_edges = pd.DataFrame(
dict(parent_branch_index=[], child_branch_index=[])
)
Expand Down Expand Up @@ -65,3 +60,109 @@ def __call__(self, loc: float):

index = index_of_loc(0, loc, self.pointer.nseg) if loc != "all" else "all"
return super().adjust_view("comp_index", index)

def connect(self, post: "CompartmentView", synapse_type):
michaeldeistler marked this conversation as resolved.
Show resolved Hide resolved
"""Connect two compartments with a chemical synapse.

High-level strategy:

We need to first check if the network already has a type of this synapse, else
we need to register it as a new synapse in a bunch of dictionaries which track
synapse parameters, state and meta information.

Next, we register the new connection in the synapse dataframe (`.syn_edges`).
Then, we update synapse parameter and state arrays with the new connection.
Finally, we update synapse meta information.
"""
synapse_name = type(synapse_type).__name__
is_new_type = True if synapse_name not in self.pointer.synapse_names else False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would start with a short paragraph as a comment describing the implementation strategy to scaffold the information for any dev, reviewer, user.

# High-level strategy
# We need to first check if the compartment [branch, cell, network ?] already has a type of this synapse, else we need to register it as a new synapse in a bunch of dictionaries which track synapse parameters, state and meta information [Is this correctly inferred from the code?]. 
# Next, we register the new connection in the synapse dataframe.
# Then, we update synapse parameter and state arrays with the new connection.
# Finally, we update synapse meta information.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, copied it almost verbatim into the docstring. And yes, your description is correctly inferred from the code.


if is_new_type:
# New type: index for the synapse type is one more than the currently
# highest index.
max_ind = self.pointer.syn_edges["type_ind"].max() + 1
type_ind = 0 if jnp.isnan(max_ind) else max_ind
else:
# Not a new type: search for the index that this type has previously had.
type_ind = self.pointer.syn_edges.query(f"type == '{synapse_name}'")[
"type_ind"
].to_numpy()[0]

# The `syn_edges` dataframe expects the compartment as continuous `loc`, not
# as discrete compartment index (because the continuous `loc` is used for
# plotting). Below, we cast the compartment index to its (rough) location.
pre_comp = loc_of_index(
self.view["global_comp_index"].to_numpy(), self.pointer.nseg
)
post_comp = loc_of_index(
post.view["global_comp_index"].to_numpy(), self.pointer.nseg
)

# Update edges.
self.pointer.syn_edges = pd.concat(
[
self.pointer.syn_edges,
pd.DataFrame(
dict(
pre_locs=pre_comp,
post_locs=post_comp,
pre_branch_index=self.view["branch_index"].to_numpy(),
post_branch_index=post.view["branch_index"].to_numpy(),
pre_cell_index=self.view["cell_index"].to_numpy(),
post_cell_index=post.view["cell_index"].to_numpy(),
type=synapse_name,
type_ind=type_ind,
global_pre_comp_index=self.view["global_comp_index"].to_numpy(),
global_post_comp_index=post.view[
"global_comp_index"
].to_numpy(),
global_pre_branch_index=self.view[
"global_branch_index"
].to_numpy(),
global_post_branch_index=post.view[
"global_branch_index"
].to_numpy(),
)
),
],
ignore_index=True,
)

# We add a column called index which is used by `adjust_view` of the
# `SynapseView` (see `network.py`).
self.pointer.syn_edges["index"] = list(self.pointer.syn_edges.index)

# Update synaptic parameter array.
for key in synapse_type.synapse_params:
param_vals = jnp.asarray([synapse_type.synapse_params[key]])
if is_new_type:
# Register parameter array for new synapse type.
self.pointer.syn_params[key] = param_vals
else:
# Append to synaptic parameter array.
self.pointer.syn_params[key] = jnp.concatenate(
[self.pointer.syn_params[key], param_vals]
)

# Update synaptic state array.
for key in synapse_type.synapse_states:
state_vals = jnp.asarray([synapse_type.synapse_states[key]])
if is_new_type:
# Register parameter array for new synapse type.
self.pointer.syn_states[key] = state_vals
else:
# Append to synaptic parameter array.
self.pointer.syn_states[key] = jnp.concatenate(
[self.pointer.syn_states[key], state_vals]
)

# (Potentially) update variables that track meta information about synapses.
if is_new_type:
self.pointer.synapse_names.append(type(synapse_type).__name__)
self.pointer.synapse_param_names.append(
list(synapse_type.synapse_params.keys())
)
self.pointer.synapse_state_names.append(
list(synapse_type.synapse_states.keys())
)
self.pointer.syn_classes.append(synapse_type)
12 changes: 5 additions & 7 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def __init__(

self.cells = cells
self.connectivities = connectivities
self.conns = [connectivity.synapse_type for connectivity in connectivities]
self.syn_classes = [
connectivity.synapse_type for connectivity in connectivities
]
self.nseg = cells[0].nseg
self.synapse_names = [type(c.synapse_type).__name__ for c in connectivities]
self.synapse_param_names = [
Expand Down Expand Up @@ -456,6 +458,8 @@ class SynapseView(View):

def __init__(self, pointer, view, key):
view = view[view["type"] == key]
view = view.reset_index(drop=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did it become necessary to reset the index here?

Copy link
Contributor Author

@michaeldeistler michaeldeistler Dec 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good catch. This is unrelated to the .connect() method. It fixes a bug that only surfaced when I added tests that add synapses of different types:

pre.connect(post, GlutamateSynapse())
pre.connect(post, TestSynapse())

view["index"] = list(view.index)
view = view.assign(controlled_by_param=view.index)
super().__init__(pointer, view)

Expand Down Expand Up @@ -487,12 +491,6 @@ def show(

return nodes

def adjust_view(self, key: str, index: float):
"""Update view."""
if index != "all":
self.view = self.view[self.view[key] == index]
return self

def set_params(self, key: str, val: float):
"""Set parameters of the pointer."""
assert (
Expand Down
1 change: 1 addition & 0 deletions jaxley/synapses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from jaxley.synapses.glutamate import GlutamateSynapse
from jaxley.synapses.synapse import Synapse
from jaxley.synapses.test import TestSynapse
2 changes: 1 addition & 1 deletion jaxley/synapses/glutamate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class GlutamateSynapse(Synapse):
"""
Compute syanptic current and update syanpse state for Glutamate receptor.
Compute syanptic current and update synapse state for Glutamate receptor.
"""

synapse_params = {"gS": 0.5}
Expand Down
41 changes: 41 additions & 0 deletions jaxley/synapses/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Dict, Tuple

import jax.numpy as jnp

from jaxley.synapses.synapse import Synapse


class TestSynapse(Synapse):
"""
Compute syanptic current and update syanpse state for a test synapse.
michaeldeistler marked this conversation as resolved.
Show resolved Hide resolved
"""

synapse_params = {"gC": 0.5}
synapse_states = {"c": 0.2}

@staticmethod
def step(
u: Dict[str, jnp.ndarray],
delta_t: float,
voltages: jnp.ndarray,
params: Dict[str, jnp.ndarray],
pre_inds: jnp.ndarray,
) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:
"""Return updated synapse state and current."""
e_syn = 0.0
v_th = -35.0
delta = 10.0
k_minus = 1.0 / 40.0

s_bar = 1.0 / (1.0 + jnp.exp((v_th - voltages[pre_inds]) / delta))
tau_s = (1.0 - s_bar) / k_minus

s_inf = s_bar
slope = -1.0 / tau_s
exp_term = jnp.exp(slope * delta_t)
new_s = u["c"] * exp_term + s_inf * (1.0 - exp_term)

non_zero_voltage_term = params["gC"] * u["c"]
non_zero_constant_term = params["gC"] * u["c"] * e_syn

return {"c": new_s}, (non_zero_voltage_term, non_zero_constant_term)
7 changes: 7 additions & 0 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:
return branch_ind * nseg_per_branch + ind_along_branch


def loc_of_index(global_comp_index, nseg):
"""Return location corresponding to index."""
index = global_comp_index % nseg
possible_locs = np.linspace(1, 0, nseg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why reversed from 1 to 0 instead of 0 to 1? Cause I would expect that index 0 corresponds to location 0.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, indeed, this is unintitive. See #30

return possible_locs[index]


def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):
midpoint_r_a = 0.5 * (r_a1 + r_a2)
return rad1 * rad2**2 / midpoint_r_a / (rad2**2 * l1 + rad1**2 * l2) / l1
Expand Down
Loading