Skip to content

Commit

Permalink
rename global_pre_comps (#518)
Browse files Browse the repository at this point in the history
* fix: rename global_pre_comps

* fix: replace remaining

* fix: wrong replace fixed

* fix: make copy_nodes_to_edges work
  • Loading branch information
jnsbck authored Nov 20, 2024
1 parent caf238c commit 5e3637c
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 31 deletions.
12 changes: 6 additions & 6 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def __init__(self):
self.edges = pd.DataFrame(
columns=[
"global_edge_index",
"global_pre_comp_index",
"global_post_comp_index",
"pre_global_comp_index",
"post_global_comp_index",
"pre_locs",
"post_locs",
"type",
Expand Down Expand Up @@ -2318,7 +2318,7 @@ def copy_node_property_to_edges(
self.nodes[[property_to_import, "global_comp_index"]].set_index(
"global_comp_index"
),
on=f"global_{pre_or_post_val}_comp_index",
on=f"{pre_or_post_val}_global_comp_index",
)
self.edges = self.edges.rename(
columns={
Expand Down Expand Up @@ -2466,8 +2466,8 @@ def _set_inds_in_view(
incl_comps = pointer.nodes.loc[
self._nodes_in_view, "global_comp_index"
].unique()
pre = base_edges["global_pre_comp_index"].isin(incl_comps).to_numpy()
post = base_edges["global_post_comp_index"].isin(incl_comps).to_numpy()
pre = base_edges["pre_global_comp_index"].isin(incl_comps).to_numpy()
post = base_edges["post_global_comp_index"].isin(incl_comps).to_numpy()
possible_edges_in_view = base_edges.index.to_numpy()[(pre & post).flatten()]
self._edges_in_view = np.intersect1d(
possible_edges_in_view, self._edges_in_view
Expand All @@ -2476,7 +2476,7 @@ def _set_inds_in_view(
base_nodes = self.base.nodes
self._edges_in_view = edges
incl_comps = pointer.edges.loc[
self._edges_in_view, ["global_pre_comp_index", "global_post_comp_index"]
self._edges_in_view, ["pre_global_comp_index", "post_global_comp_index"]
]
incl_comps = np.unique(incl_comps.to_numpy().flatten())
where_comps = base_nodes["global_comp_index"].isin(incl_comps)
Expand Down
20 changes: 10 additions & 10 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def _step_synapse_state(
voltages = states["v"]

grouped_syns = edges.groupby("type", sort=False, group_keys=False)
pre_syn_inds = grouped_syns["global_pre_comp_index"].apply(list)
post_syn_inds = grouped_syns["global_post_comp_index"].apply(list)
pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list)
post_syn_inds = grouped_syns["post_global_comp_index"].apply(list)
synapse_names = list(grouped_syns.indices.keys())

for i, synapse_type in enumerate(syn_channels):
Expand Down Expand Up @@ -309,8 +309,8 @@ def _synapse_currents(
voltages = states["v"]

grouped_syns = edges.groupby("type", sort=False, group_keys=False)
pre_syn_inds = grouped_syns["global_pre_comp_index"].apply(list)
post_syn_inds = grouped_syns["global_post_comp_index"].apply(list)
pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list)
post_syn_inds = grouped_syns["post_global_comp_index"].apply(list)
synapse_names = list(grouped_syns.indices.keys())

syn_voltage_terms = jnp.zeros_like(voltages)
Expand Down Expand Up @@ -471,10 +471,10 @@ def vis(

pre_locs = self.edges["pre_locs"].to_numpy()
post_locs = self.edges["post_locs"].to_numpy()
pre_comp = self.edges["global_pre_comp_index"].to_numpy()
pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy()
post_comp = self.edges["global_post_comp_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy()

dims_np = np.asarray(dims)
Expand Down Expand Up @@ -536,10 +536,10 @@ def build_extents(*subset_sizes):
else:
graph.add_nodes_from(range(len(self._cells_in_view)))

pre_comp = self.edges["global_pre_comp_index"].to_numpy()
pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy()
post_comp = self.edges["global_post_comp_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy()

inds = np.stack([pre_cell, post_cell]).T
Expand Down Expand Up @@ -583,9 +583,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):

# Define new synapses. Each row is one synapse.
pre_nodes = pre_nodes[["global_comp_index"]]
pre_nodes.columns = ["global_pre_comp_index"]
pre_nodes.columns = ["pre_global_comp_index"]
post_nodes = post_nodes[["global_comp_index"]]
post_nodes.columns = ["global_post_comp_index"]
post_nodes.columns = ["post_global_comp_index"]
new_rows = pd.concat(
[
global_edge_index,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_connect(SimpleBranch, SimpleCell, SimpleNet):
# check if all connections are made correctly
first_set_edges = net2.edges.iloc[:8]
nodes = net2.nodes.set_index("global_comp_index")
cols = ["global_pre_comp_index", "global_post_comp_index"]
cols = ["pre_global_comp_index", "post_global_comp_index"]
comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()]
branch_inds = comp_inds["global_branch_index"].to_numpy().reshape(-1, 2)
cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2)
Expand All @@ -84,7 +84,7 @@ def test_fully_connect():
fully_connect(net[8:12], net[12:16], TestSynapse())

assert all(
net.edges.global_post_comp_index
net.edges.post_global_comp_index
== [
108,
135,
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_connectivity_matrix_connect(SimpleNet):
)
assert len(net.edges.index) == 4
nodes = net.nodes.set_index("global_comp_index")
cols = ["global_pre_comp_index", "global_post_comp_index"]
cols = ["pre_global_comp_index", "post_global_comp_index"]
comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()]
cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2)
assert np.all(cell_inds == incides_of_connected_cells)
Expand All @@ -192,7 +192,7 @@ def test_connectivity_matrix_connect(SimpleNet):
)
assert len(net.edges.index) == 5
nodes = net.nodes.set_index("global_comp_index")
cols = ["global_pre_comp_index", "global_post_comp_index"]
cols = ["pre_global_comp_index", "post_global_comp_index"]
comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()]
cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2)
assert np.all(cell_inds == incides_of_connected_cells)
20 changes: 10 additions & 10 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,14 @@ def test_copy_node_property_to_edges(SimpleNet):
assert "post_HH_gNa" not in net.edges.columns

# Query the second cell. Each cell has four compartments.
edges_gna_values = net.edges.query("global_pre_comp_index > 3")
edges_gna_values = edges_gna_values.query("global_pre_comp_index <= 7")
edges_gna_values = net.edges.query("pre_global_comp_index > 3")
edges_gna_values = edges_gna_values.query("pre_global_comp_index <= 7")
assert np.all(edges_gna_values["pre_HH_gNa"] == 1.0)

# Query the other cells. The first cell has four compartments.
edges_gna_values = net.edges.query("global_pre_comp_index <= 3")
edges_gna_values = net.edges.query("pre_global_comp_index <= 3")
assert np.all(edges_gna_values["pre_HH_gNa"] == 0.12)
edges_gna_values = net.edges.query("global_pre_comp_index > 7")
edges_gna_values = net.edges.query("pre_global_comp_index > 7")
assert np.all(edges_gna_values["pre_HH_gNa"] == 0.12)

# Test whether multiple properties can be copied over.
Expand All @@ -234,27 +234,27 @@ def test_copy_node_property_to_edges(SimpleNet):
assert "pre_length" in net.edges.columns
assert "post_length" in net.edges.columns

edges_gna_values = net.edges.query("global_pre_comp_index <= 3")
edges_gna_values = net.edges.query("pre_global_comp_index <= 3")
assert np.all(edges_gna_values["pre_radius"] == 0.2)

edges_gna_values = net.edges.query("global_pre_comp_index > 3")
edges_gna_values = net.edges.query("pre_global_comp_index > 3")
assert np.all(edges_gna_values["pre_radius"] == 1.0)

# Test whether modifying an individual compartment also takes effect.
net.copy_node_property_to_edges(["capacitance"])
assert "pre_capacitance" in net.edges.columns
assert "post_capacitance" in net.edges.columns

edges_gna_values = net.edges.query("global_pre_comp_index == 4")
edges_gna_values = net.edges.query("pre_global_comp_index == 4")
assert np.all(edges_gna_values["pre_capacitance"] == 0.3)

edges_gna_values = net.edges.query("global_post_comp_index == 4")
edges_gna_values = net.edges.query("post_global_comp_index == 4")
assert np.all(edges_gna_values["post_capacitance"] == 0.3)

edges_gna_values = net.edges.query("global_pre_comp_index != 4")
edges_gna_values = net.edges.query("pre_global_comp_index != 4")
assert np.all(edges_gna_values["pre_capacitance"] == 1.0)

edges_gna_values = net.edges.query("global_post_comp_index != 4")
edges_gna_values = net.edges.query("post_global_comp_index != 4")
assert np.all(edges_gna_values["post_capacitance"] == 1.0)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_viewing.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_select(SimpleNet):

# check if pre and post comps of edges are in nodes
edge_node_inds = np.unique(
view.edges[["global_pre_comp_index", "global_post_comp_index"]]
view.edges[["pre_global_comp_index", "post_global_comp_index"]]
.to_numpy()
.flatten()
)
Expand Down

0 comments on commit 5e3637c

Please sign in to comment.