From 5e3637c26e03b433b9a003c26186ecea6d4596eb Mon Sep 17 00:00:00 2001 From: jnsbck <65561470+jnsbck@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:55:50 +0100 Subject: [PATCH] rename global_pre_comps (#518) * fix: rename global_pre_comps * fix: replace remaining * fix: wrong replace fixed * fix: make copy_nodes_to_edges work --- jaxley/modules/base.py | 12 ++++++------ jaxley/modules/network.py | 20 ++++++++++---------- tests/test_connection.py | 8 ++++---- tests/test_make_trainable.py | 20 ++++++++++---------- tests/test_viewing.py | 2 +- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d729eb3b..1409a6eb 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -127,8 +127,8 @@ def __init__(self): self.edges = pd.DataFrame( columns=[ "global_edge_index", - "global_pre_comp_index", - "global_post_comp_index", + "pre_global_comp_index", + "post_global_comp_index", "pre_locs", "post_locs", "type", @@ -2318,7 +2318,7 @@ def copy_node_property_to_edges( self.nodes[[property_to_import, "global_comp_index"]].set_index( "global_comp_index" ), - on=f"global_{pre_or_post_val}_comp_index", + on=f"{pre_or_post_val}_global_comp_index", ) self.edges = self.edges.rename( columns={ @@ -2466,8 +2466,8 @@ def _set_inds_in_view( incl_comps = pointer.nodes.loc[ self._nodes_in_view, "global_comp_index" ].unique() - pre = base_edges["global_pre_comp_index"].isin(incl_comps).to_numpy() - post = base_edges["global_post_comp_index"].isin(incl_comps).to_numpy() + pre = base_edges["pre_global_comp_index"].isin(incl_comps).to_numpy() + post = base_edges["post_global_comp_index"].isin(incl_comps).to_numpy() possible_edges_in_view = base_edges.index.to_numpy()[(pre & post).flatten()] self._edges_in_view = np.intersect1d( possible_edges_in_view, self._edges_in_view @@ -2476,7 +2476,7 @@ def _set_inds_in_view( base_nodes = self.base.nodes self._edges_in_view = edges incl_comps = pointer.edges.loc[ - self._edges_in_view, ["global_pre_comp_index", "global_post_comp_index"] + self._edges_in_view, ["pre_global_comp_index", "post_global_comp_index"] ] incl_comps = np.unique(incl_comps.to_numpy().flatten()) where_comps = base_nodes["global_comp_index"].isin(incl_comps) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 4fb2fe19..2966cfd7 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -262,8 +262,8 @@ def _step_synapse_state( voltages = states["v"] grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["global_pre_comp_index"].apply(list) - post_syn_inds = grouped_syns["global_post_comp_index"].apply(list) + pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) + post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) synapse_names = list(grouped_syns.indices.keys()) for i, synapse_type in enumerate(syn_channels): @@ -309,8 +309,8 @@ def _synapse_currents( voltages = states["v"] grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["global_pre_comp_index"].apply(list) - post_syn_inds = grouped_syns["global_post_comp_index"].apply(list) + pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) + post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) synapse_names = list(grouped_syns.indices.keys()) syn_voltage_terms = jnp.zeros_like(voltages) @@ -471,10 +471,10 @@ def vis( pre_locs = self.edges["pre_locs"].to_numpy() post_locs = self.edges["post_locs"].to_numpy() - pre_comp = self.edges["global_pre_comp_index"].to_numpy() + pre_comp = self.edges["pre_global_comp_index"].to_numpy() nodes = self.nodes.set_index("global_comp_index") pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy() - post_comp = self.edges["global_post_comp_index"].to_numpy() + post_comp = self.edges["post_global_comp_index"].to_numpy() post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy() dims_np = np.asarray(dims) @@ -536,10 +536,10 @@ def build_extents(*subset_sizes): else: graph.add_nodes_from(range(len(self._cells_in_view))) - pre_comp = self.edges["global_pre_comp_index"].to_numpy() + pre_comp = self.edges["pre_global_comp_index"].to_numpy() nodes = self.nodes.set_index("global_comp_index") pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy() - post_comp = self.edges["global_post_comp_index"].to_numpy() + post_comp = self.edges["post_global_comp_index"].to_numpy() post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy() inds = np.stack([pre_cell, post_cell]).T @@ -583,9 +583,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Define new synapses. Each row is one synapse. pre_nodes = pre_nodes[["global_comp_index"]] - pre_nodes.columns = ["global_pre_comp_index"] + pre_nodes.columns = ["pre_global_comp_index"] post_nodes = post_nodes[["global_comp_index"]] - post_nodes.columns = ["global_post_comp_index"] + post_nodes.columns = ["post_global_comp_index"] new_rows = pd.concat( [ global_edge_index, diff --git a/tests/test_connection.py b/tests/test_connection.py index c40f1e22..bb8d1b04 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -57,7 +57,7 @@ def test_connect(SimpleBranch, SimpleCell, SimpleNet): # check if all connections are made correctly first_set_edges = net2.edges.iloc[:8] nodes = net2.nodes.set_index("global_comp_index") - cols = ["global_pre_comp_index", "global_post_comp_index"] + cols = ["pre_global_comp_index", "post_global_comp_index"] comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()] branch_inds = comp_inds["global_branch_index"].to_numpy().reshape(-1, 2) cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) @@ -84,7 +84,7 @@ def test_fully_connect(): fully_connect(net[8:12], net[12:16], TestSynapse()) assert all( - net.edges.global_post_comp_index + net.edges.post_global_comp_index == [ 108, 135, @@ -171,7 +171,7 @@ def test_connectivity_matrix_connect(SimpleNet): ) assert len(net.edges.index) == 4 nodes = net.nodes.set_index("global_comp_index") - cols = ["global_pre_comp_index", "global_post_comp_index"] + cols = ["pre_global_comp_index", "post_global_comp_index"] comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) @@ -192,7 +192,7 @@ def test_connectivity_matrix_connect(SimpleNet): ) assert len(net.edges.index) == 5 nodes = net.nodes.set_index("global_comp_index") - cols = ["global_pre_comp_index", "global_post_comp_index"] + cols = ["pre_global_comp_index", "post_global_comp_index"] comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) assert np.all(cell_inds == incides_of_connected_cells) diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index db909a29..79bc51a5 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -217,14 +217,14 @@ def test_copy_node_property_to_edges(SimpleNet): assert "post_HH_gNa" not in net.edges.columns # Query the second cell. Each cell has four compartments. - edges_gna_values = net.edges.query("global_pre_comp_index > 3") - edges_gna_values = edges_gna_values.query("global_pre_comp_index <= 7") + edges_gna_values = net.edges.query("pre_global_comp_index > 3") + edges_gna_values = edges_gna_values.query("pre_global_comp_index <= 7") assert np.all(edges_gna_values["pre_HH_gNa"] == 1.0) # Query the other cells. The first cell has four compartments. - edges_gna_values = net.edges.query("global_pre_comp_index <= 3") + edges_gna_values = net.edges.query("pre_global_comp_index <= 3") assert np.all(edges_gna_values["pre_HH_gNa"] == 0.12) - edges_gna_values = net.edges.query("global_pre_comp_index > 7") + edges_gna_values = net.edges.query("pre_global_comp_index > 7") assert np.all(edges_gna_values["pre_HH_gNa"] == 0.12) # Test whether multiple properties can be copied over. @@ -234,10 +234,10 @@ def test_copy_node_property_to_edges(SimpleNet): assert "pre_length" in net.edges.columns assert "post_length" in net.edges.columns - edges_gna_values = net.edges.query("global_pre_comp_index <= 3") + edges_gna_values = net.edges.query("pre_global_comp_index <= 3") assert np.all(edges_gna_values["pre_radius"] == 0.2) - edges_gna_values = net.edges.query("global_pre_comp_index > 3") + edges_gna_values = net.edges.query("pre_global_comp_index > 3") assert np.all(edges_gna_values["pre_radius"] == 1.0) # Test whether modifying an individual compartment also takes effect. @@ -245,16 +245,16 @@ def test_copy_node_property_to_edges(SimpleNet): assert "pre_capacitance" in net.edges.columns assert "post_capacitance" in net.edges.columns - edges_gna_values = net.edges.query("global_pre_comp_index == 4") + edges_gna_values = net.edges.query("pre_global_comp_index == 4") assert np.all(edges_gna_values["pre_capacitance"] == 0.3) - edges_gna_values = net.edges.query("global_post_comp_index == 4") + edges_gna_values = net.edges.query("post_global_comp_index == 4") assert np.all(edges_gna_values["post_capacitance"] == 0.3) - edges_gna_values = net.edges.query("global_pre_comp_index != 4") + edges_gna_values = net.edges.query("pre_global_comp_index != 4") assert np.all(edges_gna_values["pre_capacitance"] == 1.0) - edges_gna_values = net.edges.query("global_post_comp_index != 4") + edges_gna_values = net.edges.query("post_global_comp_index != 4") assert np.all(edges_gna_values["post_capacitance"] == 1.0) diff --git a/tests/test_viewing.py b/tests/test_viewing.py index 5ebf65d1..1e38eb8e 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -359,7 +359,7 @@ def test_select(SimpleNet): # check if pre and post comps of edges are in nodes edge_node_inds = np.unique( - view.edges[["global_pre_comp_index", "global_post_comp_index"]] + view.edges[["pre_global_comp_index", "post_global_comp_index"]] .to_numpy() .flatten() )