From dbb1b34a1f75a793c0a752153bbdee3b393a83b2 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 29 Oct 2024 11:51:23 +0100 Subject: [PATCH 01/26] fix: v1 of new get_all_parameters and to_jax --- jaxley/modules/base.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2893f983..7f40262f 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -734,22 +734,18 @@ def to_jax(self): they can be processed on GPU/TPU and such that the simulation can be differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. """ - self.base.jaxnodes = {} - for key, value in self.base.nodes.to_dict(orient="list").items(): - inds = jnp.arange(len(value)) - self.base.jaxnodes[key] = jnp.asarray(value)[inds] - - # `jaxedges` contains only parameters (no indices). - # `jaxedges` contains only non-Nan elements. This is unlike the channels where - # we allow parameter sharing. - self.base.jaxedges = {} + jaxnodes = self.base.jaxnodes = {} + nodes = self.base.nodes.to_dict(orient="list") + + jaxedges = self.base.jaxedges = {} edges = self.base.edges.to_dict(orient="list") - for i, synapse in enumerate(self.base.synapses): - condition = np.asarray(edges["type_ind"]) == i - for key in synapse.synapse_params: - self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) - for key in synapse.synapse_states: - self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) + edges.pop("type") # drop since column type is string + + for jax_array, params in zip([jaxnodes, jaxedges], [nodes, edges]): + for key, value in params.items(): + inds = jnp.arange(len(value)) + jax_array[key] = jnp.asarray(value)[inds] + def show( self, @@ -1157,19 +1153,11 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # Loop only over the keys in `pstate` to avoid unnecessary computation. for parameter in pstate: key = parameter["key"] + vals_to_set = all_params if key in all_params.keys() else all_states if key in self.base.nodes.columns: - vals_to_set = all_params if key in all_params.keys() else all_states self.base.nodes[key] = vals_to_set[key] - - # `jaxedges` contains only non-Nan elements. This is unlike the channels where - # we allow parameter sharing. - edges = self.base.edges.to_dict(orient="list") - for i, synapse in enumerate(self.base.synapses): - condition = np.asarray(edges["type_ind"]) == i - for key in list(synapse.synapse_params.keys()): - self.base.edges.loc[condition, key] = all_params[key] - for key in list(synapse.synapse_states.keys()): - self.base.edges.loc[condition, key] = all_states[key] + if key in self.base.edges.columns: + self.base.edges[key] = vals_to_set[key] def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. From d5a7d036351a26eda94fb0cde15f40476fa72948 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 7 Nov 2024 13:34:21 +0100 Subject: [PATCH 02/26] enh: simplified and refactored steping currents in synapses and channels. --- jaxley/modules/base.py | 103 ++++++++++++++------------------ jaxley/modules/network.py | 111 +++++++++++++++++------------------ jaxley/utils/cell_utils.py | 2 +- tests/test_make_trainable.py | 8 +-- 4 files changed, 102 insertions(+), 122 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 7f40262f..5e670b84 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -33,7 +33,7 @@ interpolate_xyzr, loc_of_index, params_to_pstate, - query_channel_states_and_params, + query_states_and_params, v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices @@ -736,17 +736,16 @@ def to_jax(self): """ jaxnodes = self.base.jaxnodes = {} nodes = self.base.nodes.to_dict(orient="list") - + jaxedges = self.base.jaxedges = {} edges = self.base.edges.to_dict(orient="list") - edges.pop("type") # drop since column type is string - + edges.pop("type") # drop since column type is string + for jax_array, params in zip([jaxnodes, jaxedges], [nodes, edges]): for key, value in params.items(): inds = jnp.arange(len(value)) jax_array[key] = jnp.asarray(value)[inds] - def show( self, param_names: Optional[Union[str, List[str]]] = None, @@ -1283,17 +1282,6 @@ def get_all_parameters( inds = parameter["indices"] set_param = parameter["val"] - # This is needed since SynapseViews worked differently before. - # This mimics the old behaviour and tranformes the new indices - # to the old indices. - # TODO FROM #447: Longterm this should be gotten rid of. - # Instead edges should work similar to nodes (would also allow for - # param sharing). - synapse_inds = self.base.edges.groupby("type").rank()["global_edge_index"] - synapse_inds = (synapse_inds.astype(int) - 1).to_numpy() - if key in self.base.synapse_param_names: - inds = synapse_inds[inds] - if key in params: # Only parameters, not initial states. # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` @@ -1400,10 +1388,10 @@ def init_states(self, delta_t: float = 0.025): channel_param_names = list(channel.channel_params.keys()) channel_state_names = list(channel.channel_states.keys()) - channel_states = query_channel_states_and_params( + channel_states = query_states_and_params( states, channel_state_names, channel_indices ) - channel_params = query_channel_states_and_params( + channel_params = query_states_and_params( params, channel_param_names, channel_indices ) @@ -1932,35 +1920,33 @@ def _step_channels_state( ) -> Dict[str, jnp.ndarray]: """One integration step of the channels.""" voltages = states["v"] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] # Update states of the channels. - indices = channel_nodes["global_comp_index"].to_numpy() for channel in channels: - channel_param_names = list(channel.channel_params) - channel_param_names += [ - "radius", - "length", - "axial_resistivity", - "capacitance", - ] - channel_state_names = list(channel.channel_states) - channel_state_names += self.membrane_current_names - channel_indices = indices[channel_nodes[channel._name].astype(bool)] + has_channel = channel_nodes[channel._name] + channel_inds = channel_nodes.loc[ + has_channel, "global_comp_index" + ].to_numpy() - channel_params = query_channel_states_and_params( - params, channel_param_names, channel_indices + channel_param_names = list(channel.channel_params) + morph_params + channel_params = query_states_and_params( + params, channel_param_names, channel_inds ) - channel_states = query_channel_states_and_params( - states, channel_state_names, channel_indices + channel_state_names = ( + list(channel.channel_states) + self.membrane_current_names + ) + channel_states = query_states_and_params( + states, channel_state_names, channel_inds ) states_updated = channel.update_states( - channel_states, delta_t, voltages[channel_indices], channel_params + channel_states, delta_t, voltages[channel_inds], channel_params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in states_updated.items(): - states[key] = states[key].at[channel_indices].set(val) + states[key] = states[key].at[channel_inds].set(val) return states @@ -1977,53 +1963,53 @@ def _channel_currents( This is also updates `state` because the `state` also contains the current. """ voltages = states["v"] + morph_params = ["radius", "length", "axial_resistivity"] # Compute current through channels. - voltage_terms = jnp.zeros_like(voltages) - constant_terms = jnp.zeros_like(voltages) + zeros = jnp.zeros_like(voltages) + voltage_terms = zeros + constant_terms = zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - current_states = {} - for name in self.membrane_current_names: - current_states[name] = jnp.zeros_like(voltages) - + current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: name = channel._name - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) - indices = channel_nodes.loc[channel_nodes[name]][ + channel_inds = channel_nodes.loc[channel_nodes[name]][ "global_comp_index" ].to_numpy() - channel_params = {} - for p in channel_param_names: - channel_params[p] = params[p][indices] - channel_params["radius"] = params["radius"][indices] - channel_params["length"] = params["length"][indices] - channel_params["axial_resistivity"] = params["axial_resistivity"][indices] + channel_params = query_states_and_params( + params, list(channel.channel_params) + morph_params, channel_inds + ) + channel_states = query_states_and_params( + states, channel.channel_states, channel_inds + ) - channel_states = {} - for s in channel_state_names: - channel_states[s] = states[s][indices] + v_and_perturbed = jnp.stack( + [voltages[channel_inds], voltages[channel_inds] + diff] + ) - v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff]) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( channel_states, v_and_perturbed, channel_params ) + + # Split into voltage and constant terms. voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff - constant_term = membrane_currents[0] - voltage_term * voltages[indices] + constant_term = membrane_currents[0] - voltage_term * voltages[channel_inds] # * 1000 to convert from mA/cm^2 to uA/cm^2. - voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0) - constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0) + voltage_terms = voltage_terms.at[channel_inds].add(voltage_term * 1000.0) + constant_terms = constant_terms.at[channel_inds].add( + -constant_term * 1000.0 + ) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. current_states[channel.current_name] = ( current_states[channel.current_name] - .at[indices] + .at[channel_inds] .add(membrane_currents[0]) ) @@ -2031,7 +2017,6 @@ def _channel_currents( # recorded and used by `Channel.update_states()`. for name in self.membrane_current_names: states[name] = current_states[name] - return states, (voltage_terms, constant_terms) def _step_synapse( diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 15183bd6..98a00a69 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -21,6 +21,7 @@ convert_point_process_to_distributed, loc_of_index, merge_cells, + query_states_and_params, ) from jaxley.utils.misc_utils import concat_and_ignore_empty, cumsum_leading_zero from jaxley.utils.solver_utils import ( @@ -264,30 +265,21 @@ def _step_synapse_state( ) -> Dict: voltages = states["v"] - grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) - post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) - synapse_names = list(grouped_syns.indices.keys()) + for i, group in edges.groupby("type_ind"): + synapse = syn_channels[i] + pre_inds = group["global_pre_comp_index"].to_numpy() + post_inds = group["global_post_comp_index"].to_numpy() + edge_inds = group.index.to_numpy() - for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) - - synapse_params = {} - for p in synapse_param_names: - synapse_params[p] = params[p] - synapse_states = {} - for s in synapse_state_names: - synapse_states[s] = states[s] - - pre_inds = np.asarray(pre_syn_inds[synapse_names[i]]) - post_inds = np.asarray(post_syn_inds[synapse_names[i]]) + synapse_params = query_states_and_params( + params, synapse.synapse_params, edge_inds + ) + synapse_states = query_states_and_params( + states, synapse.synapse_states, edge_inds + ) # State updates. - states_updated = synapse_type.update_states( + states_updated = synapse.update_states( synapse_states, delta_t, voltages[pre_inds], @@ -295,9 +287,10 @@ def _step_synapse_state( synapse_params, ) - # Rebuild state. + # Rebuild state. This has to be done within the loop over channels to allow + # multiple channels which modify the same state. for key, val in states_updated.items(): - states[key] = val + states[key] = states[key].at[group.index.to_numpy()].set(val) return states @@ -311,43 +304,37 @@ def _synapse_currents( ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]: voltages = states["v"] - grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) - post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) - synapse_names = list(grouped_syns.indices.keys()) - - syn_voltage_terms = jnp.zeros_like(voltages) - syn_constant_terms = jnp.zeros_like(voltages) + # Compute current through synapses. + zeros = jnp.zeros_like(voltages) + syn_voltage_terms = zeros + syn_constant_terms = zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) - - synapse_params = {} - for p in synapse_param_names: - synapse_params[p] = params[p] - synapse_states = {} - for s in synapse_state_names: - synapse_states[s] = states[s] - - # Get pre and post indexes of the current synapse type. - pre_inds = np.asarray(pre_syn_inds[synapse_names[i]]) - post_inds = np.asarray(post_syn_inds[synapse_names[i]]) - - # Compute slope and offset of the current through every synapse. + + synapse_current_states = {f"{s._name}_current": zeros for s in syn_channels} + for i, group in edges.groupby("type_ind"): + synapse = syn_channels[i] + pre_inds = group["global_pre_comp_index"].to_numpy() + post_inds = group["global_post_comp_index"].to_numpy() + edge_inds = group.index.to_numpy() + + synapse_params = query_states_and_params( + params, synapse.synapse_params, edge_inds + ) + synapse_states = query_states_and_params( + states, synapse.synapse_states, edge_inds + ) + pre_v_and_perturbed = jnp.stack( [voltages[pre_inds], voltages[pre_inds] + diff] ) post_v_and_perturbed = jnp.stack( [voltages[post_inds], voltages[post_inds] + diff] ) + synapse_currents = vmap( - synapse_type.compute_current, in_axes=(None, 0, 0, None) + synapse.compute_current, in_axes=(None, 0, 0, None) )( synapse_states, pre_v_and_perturbed, @@ -373,14 +360,22 @@ def _synapse_currents( voltage_term, constant_term, ) - syn_voltage_terms += gathered_syn_currents[0] - syn_constant_terms -= gathered_syn_currents[1] - - # Add the synaptic currents through every compartment as state. - # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are - # compartments in the network. - # `[0]` because we only use the non-perturbed voltage. - states[f"i_{synapse_type._name}"] = synapse_currents[0] + + syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) + syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1]) + + # Save the current (for the unperturbed voltage) as a state that will + # also be passed to the state update. + synapse_current_states[f"{synapse._name}_current"] = ( + synapse_current_states[f"{synapse._name}_current"] + .at[edge_inds] + .add(synapse_currents_dist[0]) + ) + + # Copy the currents into the `state` dictionary such that they can be + # recorded and used by `Channel.update_states()`. + for name in [s._name for s in self.synapses]: + states[f"{name}_current"] = synapse_current_states[f"{name}_current"] return states, (syn_voltage_terms, syn_constant_terms) diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index d71014c2..edd8b2d6 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -686,7 +686,7 @@ def group_and_sum( return group_sums -def query_channel_states_and_params(d, keys, idcs): +def query_states_and_params(d, keys, idcs): """Get dict with subset of keys and values from d. This is used to restrict a dict where every item contains __all__ states to only diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 783461b3..0ef53655 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -102,9 +102,9 @@ def test_diverse_synapse_types(SimpleNet): assert np.all(all_parameters["length"] == 10.0) assert np.all(all_parameters["axial_resistivity"] == 5000.0) assert np.all(all_parameters["IonotropicSynapse_gS"][0] == 2.2) - assert np.all(all_parameters["IonotropicSynapse_gS"][1] == 2.2) - assert np.all(all_parameters["TestSynapse_gC"][0] == 3.3) - assert np.all(all_parameters["TestSynapse_gC"][1] == 4.4) + assert np.all(all_parameters["IonotropicSynapse_gS"][2] == 2.2) + assert np.all(all_parameters["TestSynapse_gC"][1] == 3.3) + assert np.all(all_parameters["TestSynapse_gC"][3] == 4.4) # Add another trainable parameter and test again. net.IonotropicSynapse.edge(1).make_trainable("IonotropicSynapse_gS") @@ -118,7 +118,7 @@ def test_diverse_synapse_types(SimpleNet): pstate = params_to_pstate(params, net.indices_set_by_trainables) all_parameters = net.get_all_parameters(pstate, voltage_solver="jaxley.thomas") assert np.all(all_parameters["IonotropicSynapse_gS"][0] == 2.2) - assert np.all(all_parameters["IonotropicSynapse_gS"][1] == 5.5) + assert np.all(all_parameters["IonotropicSynapse_gS"][2] == 5.5) def test_make_all_trainable_corresponds_to_set(SimpleNet): From 4107a46bbd07c258bb2b56e29c71d407cfbdc186 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 7 Nov 2024 14:26:40 +0100 Subject: [PATCH 03/26] fix: cleanup --- jaxley/modules/base.py | 51 ++++++++++++++++---------------------- jaxley/modules/network.py | 48 ++++++++++++----------------------- jaxley/utils/cell_utils.py | 2 +- 3 files changed, 39 insertions(+), 62 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 5e670b84..92ed9197 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -29,8 +29,8 @@ build_radiuses_from_xyzr, compute_axial_conductances, compute_levels, - convert_point_process_to_distributed, - interpolate_xyzr, + compute_current_density, + interpolate_xyz, loc_of_index, params_to_pstate, query_states_and_params, @@ -1922,24 +1922,21 @@ def _step_channels_state( voltages = states["v"] morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - # Update states of the channels. for channel in channels: - has_channel = channel_nodes[channel._name] - channel_inds = channel_nodes.loc[ - has_channel, "global_comp_index" - ].to_numpy() + is_channel = channel_nodes[channel._name] + channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - channel_param_names = list(channel.channel_params) + morph_params - channel_params = query_states_and_params( - params, channel_param_names, channel_inds + query_channel = lambda d, names: query_states_and_params( + d, names, channel_inds ) + channel_param_names = list(channel.channel_params) + morph_params + channel_params = query_channel(params, channel_param_names) channel_state_names = ( list(channel.channel_states) + self.membrane_current_names ) - channel_states = query_states_and_params( - states, channel_state_names, channel_inds - ) + channel_states = query_channel(states, channel_state_names) + # States updates. states_updated = channel.update_states( channel_states, delta_t, voltages[channel_inds], channel_params ) @@ -1967,29 +1964,25 @@ def _channel_currents( # Compute current through channels. zeros = jnp.zeros_like(voltages) - voltage_terms = zeros - constant_terms = zeros + voltage_terms, constant_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - name = channel._name - channel_inds = channel_nodes.loc[channel_nodes[name]][ - "global_comp_index" - ].to_numpy() + is_channel = channel_nodes[channel._name] + channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - channel_params = query_states_and_params( - params, list(channel.channel_params) + morph_params, channel_inds - ) - channel_states = query_states_and_params( - states, channel.channel_states, channel_inds + query_channel = lambda d, names: query_states_and_params( + d, names, channel_inds ) + channel_param_names = list(channel.channel_params) + morph_params + channel_params = query_channel(params, channel_param_names) + channel_states = query_channel(states, channel.channel_states) - v_and_perturbed = jnp.stack( - [voltages[channel_inds], voltages[channel_inds] + diff] - ) + v_channel = voltages[channel_inds] + v_and_perturbed = jnp.array([v_channel, v_channel + diff]) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( channel_states, v_and_perturbed, channel_params @@ -1997,7 +1990,7 @@ def _channel_currents( # Split into voltage and constant terms. voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff - constant_term = membrane_currents[0] - voltage_term * voltages[channel_inds] + constant_term = membrane_currents[0] - voltage_term * v_channel # * 1000 to convert from mA/cm^2 to uA/cm^2. voltage_terms = voltage_terms.at[channel_inds].add(voltage_term * 1000.0) @@ -2058,7 +2051,7 @@ def _get_external_input( length_single_compartment: um. """ zero_vec = jnp.zeros_like(voltages) - current = convert_point_process_to_distributed( + current = compute_current_density( i_stim, radius[i_inds], length_single_compartment[i_inds] ) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 98a00a69..7e321d8f 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -18,7 +18,7 @@ from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, compute_children_and_parents, - convert_point_process_to_distributed, + compute_current_density, loc_of_index, merge_cells, query_states_and_params, @@ -271,12 +271,9 @@ def _step_synapse_state( post_inds = group["global_post_comp_index"].to_numpy() edge_inds = group.index.to_numpy() - synapse_params = query_states_and_params( - params, synapse.synapse_params, edge_inds - ) - synapse_states = query_states_and_params( - states, synapse.synapse_states, edge_inds - ) + query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) + synapse_params = query_syn(params, synapse.synapse_params) + synapse_states = query_syn(states, synapse.synapse_states) # State updates. states_updated = synapse.update_states( @@ -306,8 +303,7 @@ def _synapse_currents( # Compute current through synapses. zeros = jnp.zeros_like(voltages) - syn_voltage_terms = zeros - syn_constant_terms = zeros + syn_voltage_terms, syn_constant_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 @@ -319,19 +315,13 @@ def _synapse_currents( post_inds = group["global_post_comp_index"].to_numpy() edge_inds = group.index.to_numpy() - synapse_params = query_states_and_params( - params, synapse.synapse_params, edge_inds - ) - synapse_states = query_states_and_params( - states, synapse.synapse_states, edge_inds - ) + query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) + synapse_params = query_syn(params, synapse.synapse_params) + synapse_states = query_syn(states, synapse.synapse_states) - pre_v_and_perturbed = jnp.stack( - [voltages[pre_inds], voltages[pre_inds] + diff] - ) - post_v_and_perturbed = jnp.stack( - [voltages[post_inds], voltages[post_inds] + diff] - ) + v_pre, v_post = voltages[pre_inds], voltages[post_inds] + pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) + post_v_and_perturbed = jnp.array([v_post, v_post + diff]) synapse_currents = vmap( synapse.compute_current, in_axes=(None, 0, 0, None) @@ -341,7 +331,7 @@ def _synapse_currents( post_v_and_perturbed, synapse_params, ) - synapse_currents_dist = convert_point_process_to_distributed( + synapse_currents_dist = compute_current_density( synapse_currents, params["radius"][post_inds], params["length"][post_inds], @@ -349,17 +339,12 @@ def _synapse_currents( # Split into voltage and constant terms. voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff - constant_term = ( - synapse_currents_dist[0] - voltage_term * voltages[post_inds] - ) + constant_term = synapse_currents_dist[0] - voltage_term * v_post + syn_voltages = voltage_term, constant_term # Gather slope and offset for every postsynaptic compartment. - gathered_syn_currents = gather_synapes( - len(voltages), - post_inds, - voltage_term, - constant_term, - ) + num_comp = len(voltages) + gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1]) @@ -376,7 +361,6 @@ def _synapse_currents( # recorded and used by `Channel.update_states()`. for name in [s._name for s in self.synapses]: states[f"{name}_current"] = synapse_current_states[f"{name}_current"] - return states, (syn_voltage_terms, syn_constant_terms) def arrange_in_layers( diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index edd8b2d6..21f1fe55 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -610,7 +610,7 @@ def params_to_pstate( ] -def convert_point_process_to_distributed( +def compute_current_density( current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray ) -> jnp.ndarray: """Convert current point process (nA) to distributed current (uA/cm2). From 050debc8a5c6c37ecb0eb002a85ba649740ffdd7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 7 Nov 2024 14:30:55 +0100 Subject: [PATCH 04/26] fix: ran isort --- jaxley/modules/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 92ed9197..5fa96b3e 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -28,8 +28,8 @@ _compute_num_children, build_radiuses_from_xyzr, compute_axial_conductances, - compute_levels, compute_current_density, + compute_levels, interpolate_xyz, loc_of_index, params_to_pstate, From 3ffda84c04c1645ceacd218ad938afc836ce55b7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Wed, 4 Dec 2024 22:20:34 +0100 Subject: [PATCH 05/26] wip: rename channel and synapse params and enh to_jax --- jaxley/channels/channel.py | 8 +-- jaxley/channels/hh.py | 4 +- jaxley/channels/pospischil.py | 24 ++++----- jaxley/modules/base.py | 89 +++++++++++++++++----------------- jaxley/modules/network.py | 16 +++--- jaxley/synapses/ionotropic.py | 4 +- jaxley/synapses/synapse.py | 8 +-- jaxley/synapses/tanh_rate.py | 4 +- jaxley/synapses/test.py | 4 +- tests/test_channels.py | 56 ++++++++++----------- tests/test_shared_state.py | 16 +++--- tests/test_syn.py | 2 +- tests/test_synapse_indexing.py | 6 +-- 13 files changed, 120 insertions(+), 121 deletions(-) diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index 678b1e1e..b8a1dc41 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -59,22 +59,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.channel_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_params.items() + for key, value in self.params.items() } - self.channel_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index c19bf002..70fc72b5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -17,7 +17,7 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 0.12, f"{prefix}_gK": 0.036, f"{prefix}_gLeak": 0.0003, @@ -25,7 +25,7 @@ def __init__(self, name: Optional[str] = None): f"{prefix}_eK": -77.0, f"{prefix}_eLeak": -54.3, } - self.channel_states = { + self.states = { f"{prefix}_m": 0.2, f"{prefix}_h": 0.2, f"{prefix}_n": 0.2, diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 5884deac..8602a72c 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gLeak": 1e-4, f"{prefix}_eLeak": -70.0, } - self.channel_states = {} + self.states = {} self.current_name = f"i_{prefix}" def update_states( @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 50e-3, "eNa": 50.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} + self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} self.current_name = f"i_Na" def update_states( @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gK": 5e-3, "eK": -90.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_n": 0.2} + self.states = {f"{prefix}_n": 0.2} self.current_name = f"i_K" def update_states( @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gKm": 0.004e-3, f"{prefix}_taumax": 4000.0, f"eK": -90.0, } - self.channel_states = {f"{prefix}_p": 0.2} + self.states = {f"{prefix}_p": 0.2} self.current_name = f"i_K" def update_states( @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaL": 0.1e-3, "eCa": 120.0, } - self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} + self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} self.current_name = f"i_Ca" def update_states( @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaT": 0.4e-4, f"{prefix}_vx": 2.0, "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.channel_states = {f"{prefix}_u": 0.2} + self.states = {f"{prefix}_u": 0.2} self.current_name = f"i_Ca" def update_states( diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 5fa96b3e..7cb85e36 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -228,7 +228,6 @@ def __getattr__(self, key): view._set_controlled_by_param(key) # overwrites param set by edge # Ensure synapse param sharing works with `edge` # `edge` will be removed as part of #463 - view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -734,17 +733,30 @@ def to_jax(self): they can be processed on GPU/TPU and such that the simulation can be differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. """ - jaxnodes = self.base.jaxnodes = {} - nodes = self.base.nodes.to_dict(orient="list") + jaxnodes, jaxedges = {}, {} - jaxedges = self.base.jaxedges = {} - edges = self.base.edges.to_dict(orient="list") - edges.pop("type") # drop since column type is string + for jax_arrays, data, mechs in zip( + [jaxnodes, jaxedges], + [self.nodes, self.edges], + [self.channels, self.synapses], + ): + jax_arrays.update({"index": data.index.to_numpy()}) + all_inds = jax_arrays["index"] + for mech in mechs: + inds = ( + all_inds[data["type"] == mech._name] + if "type" in data.columns + else all_inds[self.nodes[mech._name]] + ) + states_params = list(mech.params) + list(mech.states) + params = data[states_params].loc[inds] + jax_arrays.update({mech._name: inds}) + jax_arrays.update(params.to_dict(orient="list")) - for jax_array, params in zip([jaxnodes, jaxedges], [nodes, edges]): - for key, value in params.items(): - inds = jnp.arange(len(value)) - jax_array[key] = jnp.asarray(value)[inds] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list")) + jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -777,12 +789,8 @@ def show( inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds cols += [ch._name for ch in self.channels] if channel_names else [] - cols += ( - sum([list(ch.channel_params) for ch in self.channels], []) if params else [] - ) - cols += ( - sum([list(ch.channel_states) for ch in self.channels], []) if states else [] - ) + cols += sum([list(ch.params) for ch in self.channels], []) if params else [] + cols += sum([list(ch.states) for ch in self.channels], []) if states else [] if not param_names is None: cols = ( @@ -911,12 +919,8 @@ def set_ncomp( start_idx = self.nodes["global_comp_index"].to_numpy()[0] ncomp_per_branch = self.base.ncomp_per_branch channel_names = [c._name for c in self.base.channels] - channel_param_names = list( - chain(*[c.channel_params for c in self.base.channels]) - ) - channel_state_names = list( - chain(*[c.channel_states for c in self.base.channels]) - ) + channel_param_names = list(chain(*[c.params for c in self.base.channels])) + channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns within_branch_radiuses = view["radius"].to_numpy() @@ -1266,11 +1270,12 @@ def get_all_parameters( A dictionary of all module parameters. """ params = {} - for key in ["radius", "length", "axial_resistivity", "capacitance"]: + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + for key in ["v"] + morph_params: params[key] = self.base.jaxnodes[key] for channel in self.base.channels: - for channel_params in channel.channel_params: + for channel_params in channel.params: params[channel_params] = self.base.jaxnodes[channel_params] for synapse_params in self.base.synapse_param_names: @@ -1303,7 +1308,7 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states = {"v": self.base.jaxnodes["v"]} # Join node and edge states into a single state dictionary. for channel in self.base.channels: - for channel_states in channel.channel_states: + for channel_states in channel.states: states[channel_states] = self.base.jaxnodes[channel_states] for synapse_states in self.base.synapse_state_names: states[synapse_states] = self.base.jaxedges[synapse_states] @@ -1386,8 +1391,8 @@ def init_states(self, delta_t: float = 0.025): ].to_numpy() voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) + channel_param_names = list(channel.params.keys()) + channel_state_names = list(channel.states.keys()) channel_states = query_states_and_params( states, channel_state_names, channel_indices ) @@ -1724,12 +1729,12 @@ def insert(self, channel: Channel): self.base.nodes.loc[self._nodes_in_view, name] = True # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_params: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key] + for key in channel.params: + self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key] # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_states: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + for key in channel.states: + self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key] def delete_channel(self, channel: Channel): """Remove a channel from the module. @@ -1929,11 +1934,9 @@ def _step_channels_state( query_channel = lambda d, names: query_states_and_params( d, names, channel_inds ) - channel_param_names = list(channel.channel_params) + morph_params + channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) - channel_state_names = ( - list(channel.channel_states) + self.membrane_current_names - ) + channel_state_names = list(channel.states) + self.membrane_current_names channel_states = query_channel(states, channel_state_names) # States updates. @@ -1977,9 +1980,9 @@ def _channel_currents( query_channel = lambda d, names: query_states_and_params( d, names, channel_inds ) - channel_param_names = list(channel.channel_params) + morph_params + channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) - channel_states = query_channel(states, channel.channel_states) + channel_states = query_channel(states, channel.states) v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) @@ -2565,13 +2568,9 @@ def _filter_trainables( ): pkey, pval = next(iter(params.items())) trainable_inds_in_view = None - if pkey in sum( - [list(c.channel_params.keys()) for c in self.base.channels], [] - ): + if pkey in sum([list(c.params.keys()) for c in self.base.channels], []): trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) - elif pkey in sum( - [list(s.synapse_params.keys()) for s in self.base.synapses], [] - ): + elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []): trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) in_view = is_viewed == np.isin(inds, trainable_inds_in_view) @@ -2634,8 +2633,8 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]): viewed_synapses += ( [syn] if in_view else [None] ) # padded with None to keep indices consistent - viewed_params += list(syn.synapse_params.keys()) if in_view else [] - viewed_states += list(syn.synapse_states.keys()) if in_view else [] + viewed_params += list(syn.params.keys()) if in_view else [] + viewed_states += list(syn.states.keys()) if in_view else [] self.synapses = viewed_synapses self.synapse_param_names = viewed_params self.synapse_state_names = viewed_states diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 7e321d8f..9bcf8084 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -272,8 +272,8 @@ def _step_synapse_state( edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.synapse_params) - synapse_states = query_syn(states, synapse.synapse_states) + synapse_params = query_syn(params, synapse.params) + synapse_states = query_syn(states, synapse.states) # State updates. states_updated = synapse.update_states( @@ -316,8 +316,8 @@ def _synapse_currents( edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.synapse_params) - synapse_states = query_syn(states, synapse.synapse_states) + synapse_params = query_syn(params, synapse.params) + synapse_states = query_syn(states, synapse.states) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) @@ -493,8 +493,8 @@ def _infer_synapse_type_ind(self, synapse_name): def _update_synapse_state_names(self, synapse_type): # (Potentially) update variables that track meta information about synapses. self.base.synapse_names.append(synapse_type._name) - self.base.synapse_param_names += list(synapse_type.synapse_params.keys()) - self.base.synapse_state_names += list(synapse_type.synapse_states.keys()) + self.base.synapse_param_names += list(synapse_type.params.keys()) + self.base.synapse_state_names += list(synapse_type.states.keys()) self.base.synapses.append(synapse_type) def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): @@ -546,9 +546,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): def _add_params_to_edges(self, synapse_type, indices): # Add parameters and states to the `.edges` table. - for key, param_val in synapse_type.synapse_params.items(): + for key, param_val in synapse_type.params.items(): self.base.edges.loc[indices, key] = param_val # Update synaptic state array. - for key, state_val in synapse_type.synapse_states.items(): + for key, state_val in synapse_type.states.items(): self.base.edges.loc[indices, key] = state_val diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index da89113f..101dd95b 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_e_syn": 0.0, f"{prefix}_k_minus": 0.025, } - self.synapse_states = {f"{prefix}_s": 0.2} + self.states = {f"{prefix}_s": 0.2} def update_states( self, diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index a3b4752f..38cd7d3f 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -38,22 +38,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.synapse_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_params.items() + for key, value in self.params.items() } - self.synapse_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index e006a278..6bbd49cc 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -16,12 +16,12 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_x_offset": -70.0, f"{prefix}_slope": 1.0, } - self.synapse_states = {} + self.states = {} def update_states( self, diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 49a7311e..84cb5d4d 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -19,8 +19,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = {f"{prefix}_gC": 1e-4} - self.synapse_states = {f"{prefix}_c": 0.2} + self.params = {f"{prefix}_gC": 1e-4} + self.states = {f"{prefix}_c": 0.2} def update_states( self, diff --git a/tests/test_channels.py b/tests/test_channels.py index 4063fd3e..fe1caead 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -25,13 +25,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" @@ -84,8 +84,8 @@ def __init__( "T": 279.45, # Kelvin (temperature) "R": 8.314, # J/(mol K) (gas constant) } - self.channel_params = {} - self.channel_states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} + self.params = {} + self.states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} self.current_name = f"i_Ca" def update_states(self, u, dt, voltages, params): @@ -117,21 +117,21 @@ def test_channel_set_name(): # channel name can be set in the constructor na = Na(name="NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() # channel name can not be changed directly k = K() with pytest.raises(AttributeError): k.name = "KPospischil" - assert "KPospischil_gNa" not in k.channel_params.keys() - assert "eNa" not in k.channel_params.keys() - assert "KPospischil_h" not in k.channel_states.keys() - assert "KPospischil_m" not in k.channel_states.keys() + assert "KPospischil_gNa" not in k.params.keys() + assert "eNa" not in k.params.keys() + assert "KPospischil_h" not in k.states.keys() + assert "KPospischil_m" not in k.states.keys() def test_channel_change_name(): @@ -139,12 +139,12 @@ def test_channel_change_name(): # (and only this way after initialization) na = Na().change_name("NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() def test_integration_with_renamed_channels(): @@ -200,12 +200,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_q10_ch": 3, f"{prefix}_q10_ch0": 22, "celsius": 22, } - self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} + self.states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} self.current_name = f"i_K" def update_states( @@ -291,8 +291,8 @@ class User(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"cumulative": 0.0} + self.params = {} + self.states = {"cumulative": 0.0} self.current_name = f"i_User" def update_states(self, states, dt, v, params): @@ -307,8 +307,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -321,8 +321,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 0de88bb5..83541acd 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -22,8 +22,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy1" @staticmethod @@ -45,8 +45,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy2" @staticmethod @@ -68,10 +68,10 @@ class CaHVA(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gCaHVA": 0.00001, # S/cm^2 } - self.channel_states = { + self.states = { f"{self._name}_m": 0.1, # Initial value for m gating variable f"{self._name}_h": 0.1, # Initial value for h gating variable "eCa": 0.0, # mV, assuming eca for demonstration @@ -140,13 +140,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" diff --git a/tests/test_syn.py b/tests/test_syn.py index 3159e036..840fb341 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -27,7 +27,7 @@ def test_set_and_querying_params_one_type(SimpleNet): connect(pre, post, IonotropicSynapse()) # Get the synapse parameters to test setting - syn_params = list(IonotropicSynapse().synapse_params.keys()) + syn_params = list(IonotropicSynapse().params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 150a5d83..d61934c4 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -68,7 +68,7 @@ def test_set_and_querying_params_one_type(synapse_type, SimpleNet): connect(pre, post, synapse_type) # Get the synapse parameters to test setting - syn_params = list(synapse_type.synapse_params.keys()) + syn_params = list(synapse_type.params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) @@ -105,8 +105,8 @@ def test_set_and_querying_params_two_types(synapse_type, SimpleNet): post = net.cell(post_ind).branch(0).loc(0.0) connect(pre, post, synapse) - type1_params = list(IonotropicSynapse().synapse_params.keys()) - synapse_type_params = list(synapse_type.synapse_params.keys()) + type1_params = list(IonotropicSynapse().params.keys()) + synapse_type_params = list(synapse_type.params.keys()) default_synapse_type = net.edges[synapse_type_params[0]].to_numpy()[[1, 3]] From f63b6ee8857eed4096560f426b9279c7c1a3d1a2 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 5 Dec 2024 12:49:36 +0100 Subject: [PATCH 06/26] wip: make get_all_params work with new indexing --- jaxley/modules/base.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 7cb85e36..9dfc9fa9 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -737,8 +737,8 @@ def to_jax(self): for jax_arrays, data, mechs in zip( [jaxnodes, jaxedges], - [self.nodes, self.edges], - [self.channels, self.synapses], + [self.base.nodes, self.base.edges], + [self.base.channels, self.base.synapses], ): jax_arrays.update({"index": data.index.to_numpy()}) all_inds = jax_arrays["index"] @@ -754,9 +754,9 @@ def to_jax(self): jax_arrays.update(params.to_dict(orient="list")) morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list")) - jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} - jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} + jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) + self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -1274,12 +1274,17 @@ def get_all_parameters( for key in ["v"] + morph_params: params[key] = self.base.jaxnodes[key] - for channel in self.base.channels: - for channel_params in channel.params: - params[channel_params] = self.base.jaxnodes[channel_params] - - for synapse_params in self.base.synapse_param_names: - params[synapse_params] = self.base.jaxedges[synapse_params] + for jax_arrays, data, mechs in zip( + [self.base.jaxnodes, self.base.jaxedges], + [self.base.nodes, self.base.edges], + [self.base.channels, self.base.synapses], + ): + for mech in mechs: + inds = jax_arrays[mech._name] + for mech_param in mech.params: + params[mech_param] = data[mech_param].to_numpy() + params[mech_param][inds] = jax_arrays[mech_param] + params[mech_param] = jnp.asarray(params[mech_param]) # Override with those parameters set by `.make_trainable()`. for parameter in pstate: From b9dd41139041f1bf2f05c8867ecf25b559d3a841 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 5 Dec 2024 14:20:03 +0100 Subject: [PATCH 07/26] wip: more tests passing, small refactor --- jaxley/modules/base.py | 112 +++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 54 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 9dfc9fa9..804a88ac 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -30,8 +30,7 @@ compute_axial_conductances, compute_current_density, compute_levels, - interpolate_xyz, - loc_of_index, + interpolate_xyzr, params_to_pstate, query_states_and_params, v_interp, @@ -228,6 +227,7 @@ def __getattr__(self, key): view._set_controlled_by_param(key) # overwrites param set by edge # Ensure synapse param sharing works with `edge` # `edge` will be removed as part of #463 + view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -1212,9 +1212,9 @@ def _get_state_names(self) -> Tuple[List, List]: """Collect all recordable / clampable states in the membrane and synapses. Returns states seperated by comps and edges.""" - channel_states = [name for c in self.channels for name in c.channel_states] + channel_states = [name for c in self.channels for name in c.states] synapse_states = [ - name for s in self.synapses if s is not None for name in s.synapse_states + name for s in self.synapses if s is not None for name in s.states ] membrane_states = ["v", "i"] + self.membrane_current_names return ( @@ -1233,6 +1233,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params + @only_allow_module + def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? + """Return states as they are set in the `.nodes` and `.edges` tables.""" + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] + global_states_or_params = morph_params if type == "params" else global_states + for key in global_states_or_params: + yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key] + + # Join node and edge states into a single state dictionary. + for jax_arrays, mechs in zip( + [self.base.jaxnodes, self.base.jaxedges], + [self.base.channels, self.base.synapses], + ): + for mech in mechs: + mech_inds = jax_arrays[mech._name] + for key in mech.__dict__[type]: + yield key, mech_inds, jax_arrays[key] + @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str @@ -1269,34 +1289,24 @@ def get_all_parameters( Returns: A dictionary of all module parameters. """ - params = {} - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - for key in ["v"] + morph_params: - params[key] = self.base.jaxnodes[key] + pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} - for jax_arrays, data, mechs in zip( - [self.base.jaxnodes, self.base.jaxedges], - [self.base.nodes, self.base.edges], - [self.base.channels, self.base.synapses], - ): - for mech in mechs: - inds = jax_arrays[mech._name] - for mech_param in mech.params: - params[mech_param] = data[mech_param].to_numpy() - params[mech_param][inds] = jax_arrays[mech_param] - params[mech_param] = jnp.asarray(params[mech_param]) + params = {} + for key, mech_inds, jax_array in self._iter_states_or_params("params"): + params[key] = jax_array - # Override with those parameters set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] + # Override with those parameters set by `.make_trainable()`. + if key in pstate_inds: + idx = pstate_inds[key] + key = pstate[idx]["key"] + inds = pstate[idx]["indices"] + set_param = pstate[idx]["val"] - if key in params: # Only parameters, not initial states. # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. + # We need to unsqueeze `set_param` to make it `(num_params, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + inds = np.searchsorted(mech_inds, inds) params[key] = params[key].at[inds].set(set_param[:, None]) # Compute conductance params and add them to the params dictionary. @@ -1305,20 +1315,6 @@ def get_all_parameters( ) return params - @only_allow_module - def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: - # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Return states as they are set in the `.nodes` and `.edges` tables.""" - self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - states = {"v": self.base.jaxnodes["v"]} - # Join node and edge states into a single state dictionary. - for channel in self.base.channels: - for channel_states in channel.states: - states[channel_states] = self.base.jaxnodes[channel_states] - for synapse_states in self.base.synapse_state_names: - states[synapse_states] = self.base.jaxedges[synapse_states] - return states - @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float @@ -1334,18 +1330,23 @@ def get_all_states( Returns: A dictionary of all states of the module. """ - states = self.base._get_states_from_nodes_and_edges() - - # Override with the initial states set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] - if key in states: # Only initial states, not parameters. - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. + pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} + states = {} + for key, mech_inds, jax_array in self._iter_states_or_params("states"): + states[key] = jax_array + + # Override with those parameters set by `.make_trainable()`. + if key in pstate_inds: + idx = pstate_inds[key] + key = pstate[idx]["key"] + inds = pstate[idx]["indices"] + set_param = pstate[idx]["val"] + + # `inds` is of shape `(num_states, num_comps_per_param)`. + # `set_param` is of shape `(num_states,)` + # We need to unsqueeze `set_param` to make it `(num_states, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + inds = np.searchsorted(mech_inds, inds) states[key] = states[key].at[inds].set(set_param[:, None]) # Add to the states the initial current through every channel. @@ -1380,8 +1381,11 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. + self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes - states = self.base._get_states_from_nodes_and_edges() + states = {} + for key, _, jax_array in self._iter_states_or_params("states"): + states[key] = jax_array # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. From 2bf99a82d704e6f335795609424eee41a7957bf2 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 5 Dec 2024 19:00:34 +0100 Subject: [PATCH 08/26] wip: more tests passing some fixes --- jaxley/modules/base.py | 201 +++++++++++++++++++---------------- jaxley/modules/branch.py | 4 +- jaxley/modules/cell.py | 3 +- jaxley/modules/network.py | 11 +- jaxley/utils/cell_utils.py | 13 +++ tests/test_make_trainable.py | 8 +- 6 files changed, 137 insertions(+), 103 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 804a88ac..de62c0a7 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -746,7 +746,7 @@ def to_jax(self): inds = ( all_inds[data["type"] == mech._name] if "type" in data.columns - else all_inds[self.nodes[mech._name]] + else all_inds[data[mech._name]] ) states_params = list(mech.params) + list(mech.states) params = data[states_params].loc[inds] @@ -1121,6 +1121,7 @@ def make_trainable( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}" ) + @only_allow_module def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): """Write the trainables into `.nodes` and `.edges`. @@ -1145,22 +1146,26 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # any kind of issues with indexing or parameter sharing (as this is fully # taken care of by `get_all_parameters()`). self.base.to_jax() - pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) - all_params = self.base.get_all_parameters(pstate, voltage_solver="jaxley.stone") - # The value for `delta_t` does not matter here because it is only used to # compute the initial current. However, the initial current cannot be made # trainable and so its value never gets used below. - all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025) + pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) + all_params_states = self.base._get_all_states_params( + pstate, + delta_t=0.025, + voltage_solver="jaxley.stone", + params=True, + states=True, + ) # Loop only over the keys in `pstate` to avoid unnecessary computation. for parameter in pstate: key = parameter["key"] - vals_to_set = all_params if key in all_params.keys() else all_states - if key in self.base.nodes.columns: - self.base.nodes[key] = vals_to_set[key] - if key in self.base.edges.columns: - self.base.edges[key] = vals_to_set[key] + mech, mech_inds = self.base._get_mech_inds_of_param_state(key) + data = ( + self.base.nodes if key in self.base.nodes.columns else self.base.edges + ) + data.loc[mech_inds, key] = all_params_states[key] def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. @@ -1233,26 +1238,87 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params - @only_allow_module - def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]: + def _iter_states_params( + self, params=False, states=False + ) -> Tuple[str, jnp.ndarray, jnp.ndarray]: # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Return states as they are set in the `.nodes` and `.edges` tables.""" + + # assert that either params or states is True + assert params or states, "Either params or states must be True." morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - global_states_or_params = morph_params if type == "params" else global_states - for key in global_states_or_params: - yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key] + global_states_params = morph_params if params else [] + global_states_params += ["v"] if states else [] + for key in global_states_params: + yield key, self.jaxnodes["index"], self.jaxnodes[key] # Join node and edge states into a single state dictionary. for jax_arrays, mechs in zip( - [self.base.jaxnodes, self.base.jaxedges], - [self.base.channels, self.base.synapses], + [self.jaxnodes, self.jaxedges], + [self.channels, self.synapses], ): for mech in mechs: mech_inds = jax_arrays[mech._name] - for key in mech.__dict__[type]: + mech_params_states = mech.__dict__["params"] if params else {} + mech_params_states.update(mech.__dict__["states"] if states else {}) + for key in mech_params_states: yield key, mech_inds, jax_arrays[key] + def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]: + jax_array = self.jaxnodes if key in self.nodes.columns else self.jaxedges + + if "_" in key and key not in ["axial_resistivity", "axial_conductances"]: + mech = key.split("_")[0] + return mech, jax_array[mech] + + return None, jax_array["index"] + + @only_allow_module + def _get_all_states_params( + self, + pstate: List[Dict], + voltage_solver=None, + delta_t=None, + all_params=None, + params=False, + states=False, + ) -> Dict[str, jnp.ndarray]: + states_params = {} + for key, _, jax_array in self.base._iter_states_params(params, states): + states_params[key] = jax_array + + # Override with those parameters set by `.make_trainable()`. + for parameter in pstate: + key = parameter["key"] + inds = parameter["indices"] + set_param = parameter["val"] + + if key in states_params: + mech, mech_inds = self.base._get_mech_inds_of_param_state(key) + # `inds` is of shape `(num_params, num_comps_per_param)`. + # `set_param` is of shape `(num_params,)` + # We need to unsqueeze `set_param` to make it `(num_params, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + inds = np.searchsorted(mech_inds, inds) + states_params[key] = states_params[key].at[inds].set(set_param[:, None]) + + if params: + # Compute conductance params and add them to the params dictionary. + states_params["axial_conductances"] = self.base._compute_axial_conductances( + params=states_params + ) + if states: + all_params = states_params if all_params is None and params else all_params + # Add to the states the initial current through every channel. + states, _ = self.base._channel_currents( + states_params, delta_t, self.base.channels, self.base.nodes, all_params + ) + + # Add to the states the initial current through every synapse. + states, _ = self.base._synapse_currents( + states_params, self.base.synapses, all_params, delta_t, self.base.edges + ) + return states_params + @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str @@ -1289,29 +1355,8 @@ def get_all_parameters( Returns: A dictionary of all module parameters. """ - pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} - - params = {} - for key, mech_inds, jax_array in self._iter_states_or_params("params"): - params[key] = jax_array - - # Override with those parameters set by `.make_trainable()`. - if key in pstate_inds: - idx = pstate_inds[key] - key = pstate[idx]["key"] - inds = pstate[idx]["indices"] - set_param = pstate[idx]["val"] - - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` - # for the `.set()` to work. This is done with `[:, None]`. - inds = np.searchsorted(mech_inds, inds) - params[key] = params[key].at[inds].set(set_param[:, None]) - - # Compute conductance params and add them to the params dictionary. - params["axial_conductances"] = self.base._compute_axial_conductances( - params=params + params = self._get_all_states_params( + pstate, params=True, voltage_solver=voltage_solver ) return params @@ -1330,33 +1375,8 @@ def get_all_states( Returns: A dictionary of all states of the module. """ - pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} - states = {} - for key, mech_inds, jax_array in self._iter_states_or_params("states"): - states[key] = jax_array - - # Override with those parameters set by `.make_trainable()`. - if key in pstate_inds: - idx = pstate_inds[key] - key = pstate[idx]["key"] - inds = pstate[idx]["indices"] - set_param = pstate[idx]["val"] - - # `inds` is of shape `(num_states, num_comps_per_param)`. - # `set_param` is of shape `(num_states,)` - # We need to unsqueeze `set_param` to make it `(num_states, 1)` - # for the `.set()` to work. This is done with `[:, None]`. - inds = np.searchsorted(mech_inds, inds) - states[key] = states[key].at[inds].set(set_param[:, None]) - - # Add to the states the initial current through every channel. - states, _ = self.base._channel_currents( - states, delta_t, self.channels, self.nodes, all_params - ) - - # Add to the states the initial current through every synapse. - states, _ = self.base._synapse_currents( - states, self.synapses, all_params, delta_t, self.edges + states = self._get_all_states_params( + pstate, states=True, all_params=all_params, delta_t=delta_t ) return states @@ -1384,7 +1404,7 @@ def init_states(self, delta_t: float = 0.025): self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes states = {} - for key, _, jax_array in self._iter_states_or_params("states"): + for key, _, jax_array in self.base._iter_states_params(states=True): states[key] = jax_array # We do not use any `pstate` for initializing. In principle, we could change @@ -2517,26 +2537,25 @@ def _set_inds_in_view( def _jax_arrays_in_view(self, pointer: Union[Module, View]): """Update jaxnodes/jaxedges to show only those currently in view.""" a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] - jaxnodes = {} if pointer.jaxnodes is not None else None - if self.jaxnodes is not None: - comp_inds = pointer.jaxnodes["global_comp_index"] - common_inds = a_intersects_b_at(comp_inds, self._nodes_in_view) - jaxnodes = { - k: v[common_inds] - for k, v in pointer.jaxnodes.items() - if len(common_inds) > 0 - } - - jaxedges = {} if pointer.jaxedges is not None else None - if pointer.jaxedges is not None: - for key, values in self.base.jaxedges.items(): - if (syn_name := key.split("_")[0]) in self.synapse_names: - syn_edges = self.base.edges[self.base.edges["type"] == syn_name] - inds = np.intersect1d( - self._edges_in_view, syn_edges.index, return_indices=True - )[2] - if len(inds) > 0: - jaxedges[key] = values[inds] + + jaxnodes = {} if self.base.jaxnodes is not None else None + jaxedges = {} if self.base.jaxedges is not None else None + + mechs = [m._name for m in self.channels + self.synapses if m is not None] + for jax_array, base_jax_array, viewed_inds in zip( + [jaxnodes, jaxedges], + [self.base.jaxnodes, self.base.jaxedges], + [self._nodes_in_view, self._edges_in_view], + ): + if base_jax_array is not None and len(viewed_inds) > 0: + for key, values in base_jax_array.items(): + mech, mech_inds = self.base._get_mech_inds_of_param_state(key) + if mech is None or mech in mechs: + jax_array[key] = values[ + a_intersects_b_at(mech_inds, viewed_inds) + ] + jax_array["index"] = np.asarray(viewed_inds) + return jaxnodes, jaxedges def _set_externals_in_view(self): diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index 74ca31a4..f237c1b1 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -10,7 +10,7 @@ from jaxley.modules.base import Module from jaxley.modules.compartment import Compartment -from jaxley.utils.cell_utils import compute_children_and_parents +from jaxley.utils.cell_utils import compute_children_and_parents, dtype_aware_concat from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices @@ -73,7 +73,7 @@ def __init__( self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch) # Indexing. - self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in compartment_list]) self._append_params_and_states(self.branch_params, self.branch_states) self.nodes["global_comp_index"] = np.arange(self.ncomp).tolist() self.nodes["global_branch_index"] = [0] * self.ncomp diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 3d6b39da..c440384b 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -18,6 +18,7 @@ compute_levels, compute_morphology_indices_in_levels, compute_parents_in_level, + dtype_aware_concat, ) from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs from jaxley.utils.solver_utils import ( @@ -102,7 +103,7 @@ def __init__( self._internal_node_inds = np.arange(self.cumsum_ncomp[-1]) # Build nodes. Has to be changed when `.set_ncomp()` is run. - self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in branch_list]) self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1]) self.nodes["global_branch_index"] = np.repeat( np.arange(self.total_nbranches), self.ncomp_per_branch diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 9bcf8084..bac5b89c 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -19,6 +19,7 @@ build_branchpoint_group_inds, compute_children_and_parents, compute_current_density, + dtype_aware_concat, loc_of_index, merge_cells, query_states_and_params, @@ -67,7 +68,7 @@ def __init__( self.total_nbranches = sum(self.nbranches_per_cell) self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell) - self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in cells]) self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1]) self.nodes["global_branch_index"] = np.repeat( np.arange(self.total_nbranches), self.ncomp_per_branch @@ -267,8 +268,8 @@ def _step_synapse_state( for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] - pre_inds = group["global_pre_comp_index"].to_numpy() - post_inds = group["global_post_comp_index"].to_numpy() + pre_inds = group["pre_global_comp_index"].to_numpy() + post_inds = group["post_global_comp_index"].to_numpy() edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) @@ -311,8 +312,8 @@ def _synapse_currents( synapse_current_states = {f"{s._name}_current": zeros for s in syn_channels} for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] - pre_inds = group["global_pre_comp_index"].to_numpy() - post_inds = group["global_post_comp_index"].to_numpy() + pre_inds = group["pre_global_comp_index"].to_numpy() + post_inds = group["post_global_comp_index"].to_numpy() edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 21f1fe55..d672640c 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -774,3 +774,16 @@ def compute_children_and_parents( child_belongs_to_branchpoint = remap_to_consecutive(par_inds) par_inds = np.unique(par_inds) return par_inds, child_inds, child_belongs_to_branchpoint + + +def dtype_aware_concat(dfs): + concat_df = pd.concat(dfs, ignore_index=True) + # replace nans with Nones + # this correctly casts float(None) -> NaN, bool(None) -> NaN, etc. + concat_df[concat_df.isna()] = None + for col in concat_df.columns[concat_df.dtypes == "object"]: + for df in dfs: + if col in df.columns: + concat_df[col] = concat_df[col].astype(df[col].dtype) + break # first match is sufficient + return concat_df diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 0ef53655..783461b3 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -102,9 +102,9 @@ def test_diverse_synapse_types(SimpleNet): assert np.all(all_parameters["length"] == 10.0) assert np.all(all_parameters["axial_resistivity"] == 5000.0) assert np.all(all_parameters["IonotropicSynapse_gS"][0] == 2.2) - assert np.all(all_parameters["IonotropicSynapse_gS"][2] == 2.2) - assert np.all(all_parameters["TestSynapse_gC"][1] == 3.3) - assert np.all(all_parameters["TestSynapse_gC"][3] == 4.4) + assert np.all(all_parameters["IonotropicSynapse_gS"][1] == 2.2) + assert np.all(all_parameters["TestSynapse_gC"][0] == 3.3) + assert np.all(all_parameters["TestSynapse_gC"][1] == 4.4) # Add another trainable parameter and test again. net.IonotropicSynapse.edge(1).make_trainable("IonotropicSynapse_gS") @@ -118,7 +118,7 @@ def test_diverse_synapse_types(SimpleNet): pstate = params_to_pstate(params, net.indices_set_by_trainables) all_parameters = net.get_all_parameters(pstate, voltage_solver="jaxley.thomas") assert np.all(all_parameters["IonotropicSynapse_gS"][0] == 2.2) - assert np.all(all_parameters["IonotropicSynapse_gS"][2] == 5.5) + assert np.all(all_parameters["IonotropicSynapse_gS"][1] == 5.5) def test_make_all_trainable_corresponds_to_set(SimpleNet): From 0eaea44f94ad7b400110c07ee9b4cf8dc0725e78 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 6 Dec 2024 18:42:24 +0100 Subject: [PATCH 09/26] wip: new lookup table added --- jaxley/modules/base.py | 86 +++++++++++++++++++++++++++++++----------- tests/test_channels.py | 4 +- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index de62c0a7..3b9be3da 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -740,14 +740,11 @@ def to_jax(self): [self.base.nodes, self.base.edges], [self.base.channels, self.base.synapses], ): - jax_arrays.update({"index": data.index.to_numpy()}) - all_inds = jax_arrays["index"] for mech in mechs: - inds = ( - all_inds[data["type"] == mech._name] - if "type" in data.columns - else all_inds[data[mech._name]] - ) + if isinstance(mech, Channel): + inds = data.index[data[mech._name]] + else: + inds = data.index[data["type"] == mech._name] states_params = list(mech.params) + list(mech.states) params = data[states_params].loc[inds] jax_arrays.update({mech._name: inds}) @@ -758,6 +755,14 @@ def to_jax(self): self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} + # the parameters and states in the jaxnodes are stored on a per-mechanism basis, + # i.e. if only compartment #2 has a HH channels, then the jaxnodes will be + # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means + # a seperate lookup table is needed to figure out which parameters and states + # belong to which mechanism and at which global indices the mechanism lives. + + self.base._update_mech_lookup_table() + def show( self, param_names: Optional[Union[str, List[str]]] = None, @@ -1240,7 +1245,7 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: def _iter_states_params( self, params=False, states=False - ) -> Tuple[str, jnp.ndarray, jnp.ndarray]: + ) -> Tuple[str, jnp.ndarray]: # TODO FROM #447: MAKE THIS WORK FOR VIEW? # assert that either params or states is True @@ -1249,7 +1254,7 @@ def _iter_states_params( global_states_params = morph_params if params else [] global_states_params += ["v"] if states else [] for key in global_states_params: - yield key, self.jaxnodes["index"], self.jaxnodes[key] + yield key, self.jaxnodes[key] # Join node and edge states into a single state dictionary. for jax_arrays, mechs in zip( @@ -1257,20 +1262,39 @@ def _iter_states_params( [self.channels, self.synapses], ): for mech in mechs: - mech_inds = jax_arrays[mech._name] mech_params_states = mech.__dict__["params"] if params else {} mech_params_states.update(mech.__dict__["states"] if states else {}) for key in mech_params_states: - yield key, mech_inds, jax_arrays[key] + yield key, jax_arrays[key] - def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]: - jax_array = self.jaxnodes if key in self.nodes.columns else self.jaxedges + def _update_mech_lookup_table(self): + chan_items = [(list(c.params) + list(c.states), c._name) for c in self.channels] + syn_items = [(list(s.params) + list(s.states), s._name) for s in self.synapses] + + chan_inds = [ + {c._name + "_index": self.nodes.index[self.nodes[c._name]].to_numpy()} + for c in self.channels + ] + syn_inds = [ + { + s._name + + "_index": self.edges.index[self.edges["type"] == s._name].to_numpy() + } + for s in self.synapses + ] + + mech_items = [dict(zip(k, [v] * len(k))) for (k, v) in chan_items + syn_items] + mech_items += chan_inds + syn_inds - if "_" in key and key not in ["axial_resistivity", "axial_conductances"]: - mech = key.split("_")[0] - return mech, jax_array[mech] + self._mech_lookup_table = {k: v for d in mech_items for k, v in d.items()} - return None, jax_array["index"] + def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]: + if key in self._mech_lookup_table: + mech = self._mech_lookup_table[key] + inds = self._mech_lookup_table[mech + "_index"] + return mech, inds + + return None, self._nodes_in_view @only_allow_module def _get_all_states_params( @@ -1283,7 +1307,7 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, _, jax_array in self.base._iter_states_params(params, states): + for key, jax_array in self.base._iter_states_params(params, states): states_params[key] = jax_array # Override with those parameters set by `.make_trainable()`. @@ -1404,7 +1428,7 @@ def init_states(self, delta_t: float = 0.025): self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes states = {} - for key, _, jax_array in self.base._iter_states_params(states=True): + for key, jax_array in self.base._iter_states_params(states=True): states[key] = jax_array # We do not use any `pstate` for initializing. In principle, we could change @@ -2537,11 +2561,28 @@ def _set_inds_in_view( def _jax_arrays_in_view(self, pointer: Union[Module, View]): """Update jaxnodes/jaxedges to show only those currently in view.""" a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes = {} if self.base.jaxnodes is not None else None - jaxedges = {} if self.base.jaxedges is not None else None + jaxnodes = {} if self.base.jaxnodes else None + jaxedges = {} if self.base.jaxedges else None + + # None check is needed for View -> see `View._jax_arrays_in_view` + chan_mechs = [m._name for m in self.channels] + syn_mechs = [m._name for m in self.synapses if m is not None] + mechs = chan_mechs + syn_mechs + if self.base._mech_lookup_table: + self._mech_lookup_table = {} + for k, v in self.base._mech_lookup_table.items(): + if "index" in k: + mech = k.replace("_index", "") + if mech in chan_mechs: + v = v[a_intersects_b_at(v, self._nodes_in_view)] + elif mech in syn_mechs: + v = v[a_intersects_b_at(v, self._edges_in_view)] + self._mech_lookup_table[k] = v + elif v in mechs + ["v"] + morph_params: + self._mech_lookup_table[k] = v - mechs = [m._name for m in self.channels + self.synapses if m is not None] for jax_array, base_jax_array, viewed_inds in zip( [jaxnodes, jaxedges], [self.base.jaxnodes, self.base.jaxedges], @@ -2554,7 +2595,6 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): jax_array[key] = values[ a_intersects_b_at(mech_inds, viewed_inds) ] - jax_array["index"] = np.asarray(viewed_inds) return jaxnodes, jaxedges diff --git a/tests/test_channels.py b/tests/test_channels.py index fe1caead..62f069c9 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -365,9 +365,7 @@ def test_delete_channel(SimpleBranch): branch3.delete_channel(K()) def channel_present(view, channel, partial=False): - states_and_params = list(channel.channel_states.keys()) + list( - channel.channel_params.keys() - ) + states_and_params = list(channel.states.keys()) + list(channel.params.keys()) # none of the states or params should be in nodes cols = view.nodes.columns.to_list() channel_cols = [ From 4b7395e25c491ef217f653080ce2f9ad02cadda3 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 6 Dec 2024 19:28:54 +0100 Subject: [PATCH 10/26] wip: more fixes --- jaxley/modules/base.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 3b9be3da..8ec9cf5d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1256,14 +1256,17 @@ def _iter_states_params( for key in global_states_params: yield key, self.jaxnodes[key] + # for key in self.synapse_current_names: + # yield key, self.jaxedges[key] + # Join node and edge states into a single state dictionary. for jax_arrays, mechs in zip( [self.jaxnodes, self.jaxedges], [self.channels, self.synapses], ): for mech in mechs: - mech_params_states = mech.__dict__["params"] if params else {} - mech_params_states.update(mech.__dict__["states"] if states else {}) + mech_params_states = mech.params if params else {} + mech_params_states.update(mech.states if states else {}) for key in mech_params_states: yield key, jax_arrays[key] @@ -1307,8 +1310,8 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_array in self.base._iter_states_params(params, states): - states_params[key] = jax_array + for key, jax_arrays in self.base._iter_states_params(params, states): + states_params[key] = jax_arrays # Override with those parameters set by `.make_trainable()`. for parameter in pstate: @@ -1428,8 +1431,8 @@ def init_states(self, delta_t: float = 0.025): self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes states = {} - for key, jax_array in self.base._iter_states_params(states=True): - states[key] = jax_array + for key, jax_arrays in self.base._iter_states_params(states=True): + states[key] = jax_arrays # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. @@ -1798,8 +1801,8 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.channel_params.keys()) - channel_cols += list(channel.channel_states.keys()) + channel_cols = list(channel.params.keys()) + channel_cols += list(channel.states.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -2583,16 +2586,16 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): elif v in mechs + ["v"] + morph_params: self._mech_lookup_table[k] = v - for jax_array, base_jax_array, viewed_inds in zip( + for jax_arrays, base_jax_arrays, viewed_inds in zip( [jaxnodes, jaxedges], [self.base.jaxnodes, self.base.jaxedges], [self._nodes_in_view, self._edges_in_view], ): - if base_jax_array is not None and len(viewed_inds) > 0: - for key, values in base_jax_array.items(): + if base_jax_arrays is not None and len(viewed_inds) > 0: + for key, values in base_jax_arrays.items(): mech, mech_inds = self.base._get_mech_inds_of_param_state(key) if mech is None or mech in mechs: - jax_array[key] = values[ + jax_arrays[key] = values.at[ a_intersects_b_at(mech_inds, viewed_inds) ] From 150cf6401c9374da525961c39338f14335a753d7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 6 Dec 2024 23:01:13 +0100 Subject: [PATCH 11/26] wip: save wip, bug hunting in _synapse_current voltages --- jaxley/modules/base.py | 2 +- jaxley/modules/network.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 8ec9cf5d..d47614ae 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -747,7 +747,7 @@ def to_jax(self): inds = data.index[data["type"] == mech._name] states_params = list(mech.params) + list(mech.states) params = data[states_params].loc[inds] - jax_arrays.update({mech._name: inds}) + # jax_arrays.update({mech._name: inds}) jax_arrays.update(params.to_dict(orient="list")) morph_params = ["radius", "length", "axial_resistivity", "capacitance"] diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index bac5b89c..a8eec32b 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -251,6 +251,7 @@ def _step_synapse( ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]: """Perform one step of the synapses and obtain their currents.""" states = self._step_synapse_state(states, syn_channels, params, delta_t, edges) + # import jax; jax.debug.print("1 {}", states["TestSynapse_c"]) states, current_terms = self._synapse_currents( states, syn_channels, params, delta_t, edges ) @@ -288,8 +289,7 @@ def _step_synapse_state( # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in states_updated.items(): - states[key] = states[key].at[group.index.to_numpy()].set(val) - + states[key] = states[key].at[:].set(val) return states def _synapse_currents( @@ -309,7 +309,7 @@ def _synapse_currents( # offset. diff = 1e-3 - synapse_current_states = {f"{s._name}_current": zeros for s in syn_channels} + synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels} for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] pre_inds = group["pre_global_comp_index"].to_numpy() @@ -320,6 +320,7 @@ def _synapse_currents( synapse_params = query_syn(params, synapse.params) synapse_states = query_syn(states, synapse.states) + num_comp = len(voltages) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) post_v_and_perturbed = jnp.array([v_post, v_post + diff]) @@ -344,24 +345,22 @@ def _synapse_currents( syn_voltages = voltage_term, constant_term # Gather slope and offset for every postsynaptic compartment. - num_comp = len(voltages) gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1]) - # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. - synapse_current_states[f"{synapse._name}_current"] = ( - synapse_current_states[f"{synapse._name}_current"] - .at[edge_inds] + synapse_current_states[f"i_{synapse._name}"] = ( + synapse_current_states[f"i_{synapse._name}"] + .at[post_inds] .add(synapse_currents_dist[0]) ) # Copy the currents into the `state` dictionary such that they can be # recorded and used by `Channel.update_states()`. for name in [s._name for s in self.synapses]: - states[f"{name}_current"] = synapse_current_states[f"{name}_current"] + states[f"i_{name}"] = synapse_current_states[f"i_{name}"] return states, (syn_voltage_terms, syn_constant_terms) def arrange_in_layers( From f1b0e1c7753c873853590b94d643e4806c3e7667 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sat, 7 Dec 2024 15:26:41 +0100 Subject: [PATCH 12/26] fix: fixed indexing --- jaxley/modules/base.py | 16 ++++------------ jaxley/modules/network.py | 17 ++++++++--------- jaxley/utils/cell_utils.py | 4 ++-- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d47614ae..f4604834 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1449,12 +1449,8 @@ def init_states(self, delta_t: float = 0.025): channel_param_names = list(channel.params.keys()) channel_state_names = list(channel.states.keys()) - channel_states = query_states_and_params( - states, channel_state_names, channel_indices - ) - channel_params = query_states_and_params( - params, channel_param_names, channel_indices - ) + channel_states = query_states_and_params(states, channel_state_names) + channel_params = query_states_and_params(params, channel_param_names) init_state = channel.init_state( channel_states, voltages, channel_params, delta_t @@ -1987,9 +1983,7 @@ def _step_channels_state( is_channel = channel_nodes[channel._name] channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - query_channel = lambda d, names: query_states_and_params( - d, names, channel_inds - ) + query_channel = lambda d, names: query_states_and_params(d, names) channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) channel_state_names = list(channel.states) + self.membrane_current_names @@ -2033,9 +2027,7 @@ def _channel_currents( is_channel = channel_nodes[channel._name] channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - query_channel = lambda d, names: query_states_and_params( - d, names, channel_inds - ) + query_channel = lambda d, names: query_states_and_params(d, names) channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) channel_states = query_channel(states, channel.states) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a8eec32b..a1532630 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -309,18 +309,16 @@ def _synapse_currents( # offset. diff = 1e-3 + num_comp = len(voltages) synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels} for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] pre_inds = group["pre_global_comp_index"].to_numpy() post_inds = group["post_global_comp_index"].to_numpy() - edge_inds = group.index.to_numpy() - query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.params) - synapse_states = query_syn(states, synapse.states) + synapse_params = query_states_and_params(params, synapse.params) + synapse_states = query_states_and_params(states, synapse.states) - num_comp = len(voltages) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) post_v_and_perturbed = jnp.array([v_post, v_post + diff]) @@ -345,6 +343,7 @@ def _synapse_currents( syn_voltages = voltage_term, constant_term # Gather slope and offset for every postsynaptic compartment. + # import jax; jax.debug.print("{}", synapse_params) gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) @@ -357,10 +356,10 @@ def _synapse_currents( .add(synapse_currents_dist[0]) ) - # Copy the currents into the `state` dictionary such that they can be - # recorded and used by `Channel.update_states()`. - for name in [s._name for s in self.synapses]: - states[f"i_{name}"] = synapse_current_states[f"i_{name}"] + # Copy the currents into the `state` dictionary such that they can be + # recorded and used by `Channel.update_states()`. + for name in [s._name for s in self.synapses]: + states[f"i_{name}"] = synapse_current_states[f"i_{name}"] return states, (syn_voltage_terms, syn_constant_terms) def arrange_in_layers( diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index d672640c..a1e35c81 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -686,7 +686,7 @@ def group_and_sum( return group_sums -def query_states_and_params(d, keys, idcs): +def query_states_and_params(d, keys, idcs=None): """Get dict with subset of keys and values from d. This is used to restrict a dict where every item contains __all__ states to only @@ -698,7 +698,7 @@ def query_states_and_params(d, keys, idcs): ```states = {'eCa': Array([ 0., 0.]}``` Only loops over necessary keys, as opposed to looping over `d.items()`.""" - return dict(zip(keys, (v[idcs] for v in map(d.get, keys)))) + return dict(zip(keys, (v if idcs is None else v[idcs] for v in map(d.get, keys)))) def compute_axial_conductances( From 05241511bcd01945f3e4ab3f3e851629f13e30f6 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sat, 7 Dec 2024 15:48:34 +0100 Subject: [PATCH 13/26] fix: fix remaining indexing issues, tests passing (I think) --- jaxley/modules/base.py | 24 ++++++++---------------- jaxley/modules/network.py | 15 +++++++-------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f4604834..c3a17d95 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1442,9 +1442,7 @@ def init_states(self, delta_t: float = 0.025): for channel in self.base.channels: name = channel._name - channel_indices = channel_nodes.loc[channel_nodes[name]][ - "global_comp_index" - ].to_numpy() + channel_indices = self._mech_lookup_table[name + "_index"] voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() channel_param_names = list(channel.params.keys()) @@ -1980,23 +1978,20 @@ def _step_channels_state( morph_params = ["radius", "length", "axial_resistivity", "capacitance"] for channel in channels: - is_channel = channel_nodes[channel._name] - channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - - query_channel = lambda d, names: query_states_and_params(d, names) channel_param_names = list(channel.params) + morph_params - channel_params = query_channel(params, channel_param_names) + channel_params = query_states_and_params(params, channel_param_names) channel_state_names = list(channel.states) + self.membrane_current_names - channel_states = query_channel(states, channel_state_names) + channel_states = query_states_and_params(states, channel_state_names) # States updates. + channel_inds = self._mech_lookup_table[channel._name + "_index"] states_updated = channel.update_states( channel_states, delta_t, voltages[channel_inds], channel_params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in states_updated.items(): - states[key] = states[key].at[channel_inds].set(val) + states[key] = states[key].at[:].set(val) return states @@ -2024,14 +2019,11 @@ def _channel_currents( current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - is_channel = channel_nodes[channel._name] - channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy() - - query_channel = lambda d, names: query_states_and_params(d, names) channel_param_names = list(channel.params) + morph_params - channel_params = query_channel(params, channel_param_names) - channel_states = query_channel(states, channel.states) + channel_params = query_states_and_params(params, channel_param_names) + channel_states = query_states_and_params(states, channel.states) + channel_inds = self._mech_lookup_table[channel._name + "_index"] v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a1532630..3f1f7339 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -271,11 +271,9 @@ def _step_synapse_state( synapse = syn_channels[i] pre_inds = group["pre_global_comp_index"].to_numpy() post_inds = group["post_global_comp_index"].to_numpy() - edge_inds = group.index.to_numpy() - query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.params) - synapse_states = query_syn(states, synapse.states) + synapse_params = query_states_and_params(params, synapse.params) + synapse_states = query_states_and_params(states, synapse.states) # State updates. states_updated = synapse.update_states( @@ -290,6 +288,7 @@ def _step_synapse_state( # multiple channels which modify the same state. for key, val in states_updated.items(): states[key] = states[key].at[:].set(val) + return states def _synapse_currents( @@ -356,10 +355,10 @@ def _synapse_currents( .add(synapse_currents_dist[0]) ) - # Copy the currents into the `state` dictionary such that they can be - # recorded and used by `Channel.update_states()`. - for name in [s._name for s in self.synapses]: - states[f"i_{name}"] = synapse_current_states[f"i_{name}"] + # Copy the currents into the `state` dictionary such that they can be + # recorded and used by `Channel.update_states()`. + for name in [s._name for s in self.synapses]: + states[f"i_{name}"] = synapse_current_states[f"i_{name}"] return states, (syn_voltage_terms, syn_constant_terms) def arrange_in_layers( From 8820561dc894a3714bd93e57f0ad66fb16bf2641 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 9 Dec 2024 15:10:19 +0100 Subject: [PATCH 14/26] wip: wip fixing multiple mechs with same param / state --- jaxley/modules/base.py | 279 ++++++++++++++++++++----------------- jaxley/modules/network.py | 30 ++-- jaxley/utils/cell_utils.py | 15 -- 3 files changed, 163 insertions(+), 161 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c3a17d95..17cf4223 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -32,7 +32,6 @@ compute_levels, interpolate_xyzr, params_to_pstate, - query_states_and_params, v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices @@ -733,35 +732,30 @@ def to_jax(self): they can be processed on GPU/TPU and such that the simulation can be differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. """ - jaxnodes, jaxedges = {}, {} - - for jax_arrays, data, mechs in zip( - [jaxnodes, jaxedges], - [self.base.nodes, self.base.edges], - [self.base.channels, self.base.synapses], - ): - for mech in mechs: - if isinstance(mech, Channel): - inds = data.index[data[mech._name]] - else: - inds = data.index[data["type"] == mech._name] - states_params = list(mech.params) + list(mech.states) - params = data[states_params].loc[inds] - # jax_arrays.update({mech._name: inds}) - jax_arrays.update(params.to_dict(orient="list")) - - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) - self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} - self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} - # the parameters and states in the jaxnodes are stored on a per-mechanism basis, # i.e. if only compartment #2 has a HH channels, then the jaxnodes will be # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means # a seperate lookup table is needed to figure out which parameters and states # belong to which mechanism and at which global indices the mechanism lives. + self._update_mech_lookup_table() + syn_param_states = sum( + [list(s.params) + list(s.states) for s in self.synapses], [] + ) - self.base._update_mech_lookup_table() + jaxnodes, jaxedges = {}, {} + + for state_param in self._mech_lookup: + mech_inds = self._get_state_param_inds(state_param) + data = self.edges if state_param in syn_param_states else self.nodes + jax_arrays = jaxedges if state_param in syn_param_states else jaxnodes + + values = data.loc[mech_inds, state_param].to_numpy() + jax_arrays.update({state_param: values}) + + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) + self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + self.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -1150,12 +1144,12 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # However, I think it allows us to reuse as much code as possible and it avoids # any kind of issues with indexing or parameter sharing (as this is fully # taken care of by `get_all_parameters()`). - self.base.to_jax() + self.to_jax() # The value for `delta_t` does not matter here because it is only used to # compute the initial current. However, the initial current cannot be made # trainable and so its value never gets used below. - pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) - all_params_states = self.base._get_all_states_params( + pstate = params_to_pstate(trainable_params, self.indices_set_by_trainables) + all_params_states = self._get_all_states_params( pstate, delta_t=0.025, voltage_solver="jaxley.stone", @@ -1166,10 +1160,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # Loop only over the keys in `pstate` to avoid unnecessary computation. for parameter in pstate: key = parameter["key"] - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) - data = ( - self.base.nodes if key in self.base.nodes.columns else self.base.edges - ) + mech_inds = self._get_state_param_inds(key) + mech_inds = np.concatenate(list(mech_inds.values())) + data = self.nodes if key in self.nodes.columns else self.edges data.loc[mech_inds, key] = all_params_states[key] def distance(self, endpoint: "View") -> float: @@ -1271,33 +1264,64 @@ def _iter_states_params( yield key, jax_arrays[key] def _update_mech_lookup_table(self): - chan_items = [(list(c.params) + list(c.states), c._name) for c in self.channels] - syn_items = [(list(s.params) + list(s.states), s._name) for s in self.synapses] - - chan_inds = [ - {c._name + "_index": self.nodes.index[self.nodes[c._name]].to_numpy()} - for c in self.channels - ] - syn_inds = [ - { - s._name - + "_index": self.edges.index[self.edges["type"] == s._name].to_numpy() + state_param_lookup = {} + mech2inds = {} + + for mech in self.channels + self.synapses: + is_channel = isinstance(mech, Channel) + data = self.nodes if is_channel else self.edges + cond = data[mech._name] if is_channel else data["type"] == mech._name + + chan_inds = {mech._name: data.index[cond].to_numpy()} + mech2inds.update(chan_inds) + for item in list(mech.params) + list(mech.states): + inds = chan_inds[mech._name] + mech_info = {mech._name: {"global_index": inds}} + if item not in state_param_lookup: + state_param_lookup[item] = mech_info + else: + state_param_lookup[item].update(mech_info) + + for state_param, info in state_param_lookup.items(): + global_inds = [v["global_index"] for v in info.values()] + comb_inds = np.unique(np.concatenate(global_inds)) + local_inds = { + k: jnp.searchsorted(comb_inds, v["global_index"]) + for k, v in info.items() } - for s in self.synapses - ] - - mech_items = [dict(zip(k, [v] * len(k))) for (k, v) in chan_items + syn_items] - mech_items += chan_inds + syn_inds + new_info = {} + for v1, (k, v2) in zip(global_inds, local_inds.items()): + new_info.update({k: {"global_index": v1, "local_index": v2}}) + state_param_lookup[state_param].update(new_info) + + self._mech_lookup = state_param_lookup + self._mech_inds = mech2inds + + def _get_state_param_inds(self, key: str) -> Tuple[str, jnp.ndarray]: + inds = None + if key in self._mech_lookup: + mechs = self._mech_lookup[key] + inds = np.unique(np.concatenate([self._mech_inds[m] for m in mechs])) + elif key in self.nodes.columns: + inds = {key: self._nodes_in_view} + elif key in self.edges.columns: + inds = {key: self._edges_in_view} + else: + raise KeyError(f"Key '{key}' not found in nodes or edges") - self._mech_lookup_table = {k: v for d in mech_items for k, v in d.items()} + return inds - def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]: - if key in self._mech_lookup_table: - mech = self._mech_lookup_table[key] - inds = self._mech_lookup_table[mech + "_index"] - return mech, inds + def _filter_states_params( + self, states_params: Dict[str, jnp.ndarray], mech: Union[Channel, Synapse] + ) -> Dict[str, jnp.ndarray]: + pkeys = list(mech.params) + mech_states_params = pkeys if pkeys[0] in states_params else list(mech.states) - return None, self._nodes_in_view + filtered_states_params = {} + for key in mech_states_params: + mech_inds = self._mech_lookup[key][mech._name]["local_index"] + filtered_states_params[key] = states_params[key][mech_inds] + return filtered_states_params @only_allow_module def _get_all_states_params( @@ -1310,39 +1334,38 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_arrays in self.base._iter_states_params(params, states): + for key, jax_arrays in self._iter_states_params(params, states): states_params[key] = jax_arrays # Override with those parameters set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] + for p in pstate: + key, inds, set_param = p["key"], p["indices"], p["val"] if key in states_params: - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - inds = np.searchsorted(mech_inds, inds) + mech_inds = self._get_state_param_inds(key) + inds = jnp.searchsorted(mech_inds, inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: # Compute conductance params and add them to the params dictionary. - states_params["axial_conductances"] = self.base._compute_axial_conductances( + states_params["axial_conductances"] = self._compute_axial_conductances( params=states_params ) + if states: all_params = states_params if all_params is None and params else all_params # Add to the states the initial current through every channel. - states, _ = self.base._channel_currents( - states_params, delta_t, self.base.channels, self.base.nodes, all_params + states, _ = self._channel_currents( + states_params, delta_t, self.channels, self.nodes, all_params ) # Add to the states the initial current through every synapse. - states, _ = self.base._synapse_currents( - states_params, self.base.synapses, all_params, delta_t, self.base.edges + states, _ = self._synapse_currents( + states_params, self.synapses, all_params, delta_t, self.edges ) return states_params @@ -1428,30 +1451,25 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. - self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - channel_nodes = self.base.nodes + self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, jax_arrays in self.base._iter_states_params(states=True): + for key, jax_arrays in self._iter_states_params(states=True): states[key] = jax_arrays # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. # `voltage_solver` could also be `jax.sparse` here, because both of them # build the channel parameters in the same way. - params = self.base.get_all_parameters([], voltage_solver="jaxley.thomas") + params = self.get_all_parameters([], voltage_solver="jaxley.thomas") + voltages = self.nodes["v"].to_numpy() - for channel in self.base.channels: - name = channel._name - channel_indices = self._mech_lookup_table[name + "_index"] - voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() - - channel_param_names = list(channel.params.keys()) - channel_state_names = list(channel.states.keys()) - channel_states = query_states_and_params(states, channel_state_names) - channel_params = query_states_and_params(params, channel_param_names) + for channel in self.channels: + channel_states = self._filter_states_params(states, channel) + channel_params = self._filter_states_params(params, channel) + channel_inds = self._mech_inds[channel._name] init_state = channel.init_state( - channel_states, voltages, channel_params, delta_t + channel_states, voltages[channel_inds], channel_params, delta_t ) # `init_state` might not return all channel states. Only the ones that are @@ -1460,7 +1478,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel_indices, key] = val + self.nodes.loc[channel_inds, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs. @@ -1795,8 +1813,7 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.params.keys()) - channel_cols += list(channel.states.keys()) + channel_cols = list(channel.params) + list(channel.states) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1976,22 +1993,26 @@ def _step_channels_state( """One integration step of the channels.""" voltages = states["v"] morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + morph_params = {pkey: params[pkey] for pkey in morph_params} + current_states = {name: states[name] for name in self.membrane_current_names} for channel in channels: - channel_param_names = list(channel.params) + morph_params - channel_params = query_states_and_params(params, channel_param_names) - channel_state_names = list(channel.states) + self.membrane_current_names - channel_states = query_states_and_params(states, channel_state_names) + channel_params = self._filter_states_params(params, channel) + channel_states = self._filter_states_params(states, channel) + + channel_params.update(morph_params) + channel_states.update(current_states) # States updates. - channel_inds = self._mech_lookup_table[channel._name + "_index"] - states_updated = channel.update_states( + channel_inds = self._mech_inds[channel._name] + channel_states_updated = channel.update_states( channel_states, delta_t, voltages[channel_inds], channel_params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in states_updated.items(): - states[key] = states[key].at[:].set(val) + for key, val in channel_states_updated.items(): + channel_inds = self._mech_lookup[key][channel._name]["local_index"] + states[key] = states[key].at[channel_inds].set(val) return states @@ -2009,21 +2030,22 @@ def _channel_currents( """ voltages = states["v"] morph_params = ["radius", "length", "axial_resistivity"] + morph_params = {pkey: params[pkey] for pkey in morph_params} # Compute current through channels. zeros = jnp.zeros_like(voltages) - voltage_terms, constant_terms = zeros, zeros + voltage_terms, const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - channel_param_names = list(channel.params) + morph_params - channel_params = query_states_and_params(params, channel_param_names) - channel_states = query_states_and_params(states, channel.states) + channel_params = self._filter_states_params(params, channel) + channel_states = self._filter_states_params(states, channel) + + channel_params.update(morph_params) - channel_inds = self._mech_lookup_table[channel._name + "_index"] + channel_inds = self._mech_inds[channel._name] v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) @@ -2033,13 +2055,11 @@ def _channel_currents( # Split into voltage and constant terms. voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff - constant_term = membrane_currents[0] - voltage_term * v_channel + const_term = membrane_currents[0] - voltage_term * v_channel # * 1000 to convert from mA/cm^2 to uA/cm^2. voltage_terms = voltage_terms.at[channel_inds].add(voltage_term * 1000.0) - constant_terms = constant_terms.at[channel_inds].add( - -constant_term * 1000.0 - ) + const_terms = const_terms.at[channel_inds].add(-const_term * 1000.0) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. @@ -2053,7 +2073,7 @@ def _channel_currents( # recorded and used by `Channel.update_states()`. for name in self.membrane_current_names: states[name] = current_states[name] - return states, (voltage_terms, constant_terms) + return states, (voltage_terms, const_terms) def _step_synapse( self, @@ -2553,35 +2573,32 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): jaxnodes = {} if self.base.jaxnodes else None jaxedges = {} if self.base.jaxedges else None - # None check is needed for View -> see `View._jax_arrays_in_view` - chan_mechs = [m._name for m in self.channels] - syn_mechs = [m._name for m in self.synapses if m is not None] - mechs = chan_mechs + syn_mechs - if self.base._mech_lookup_table: - self._mech_lookup_table = {} - for k, v in self.base._mech_lookup_table.items(): - if "index" in k: - mech = k.replace("_index", "") - if mech in chan_mechs: - v = v[a_intersects_b_at(v, self._nodes_in_view)] - elif mech in syn_mechs: - v = v[a_intersects_b_at(v, self._edges_in_view)] - self._mech_lookup_table[k] = v - elif v in mechs + ["v"] + morph_params: - self._mech_lookup_table[k] = v - - for jax_arrays, base_jax_arrays, viewed_inds in zip( - [jaxnodes, jaxedges], - [self.base.jaxnodes, self.base.jaxedges], - [self._nodes_in_view, self._edges_in_view], - ): - if base_jax_arrays is not None and len(viewed_inds) > 0: - for key, values in base_jax_arrays.items(): - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) - if mech is None or mech in mechs: - jax_arrays[key] = values.at[ - a_intersects_b_at(mech_inds, viewed_inds) - ] + # # None check is needed for View -> see `View._jax_arrays_in_view` + # chan_mechs = [m._name for m in self.channels] + # syn_mechs = [m._name for m in self.synapses if m is not None] + # mechs = chan_mechs + syn_mechs + # if self.base._mech_lookup: + # self._mech_lookup = {} + # self._mech_inds = {} + # for k, v in self.base._mech_inds.items(): + # inds = self._nodes_in_view if k in chan_mechs else self._edges_in_view + # self._mech_inds[k] = v[a_intersects_b_at(v, inds)] + # for k, v in self.base._mech_lookup.items(): + # self._mech_lookup[k] = {k:v_i for k,v_i in v.items() if k in mechs} + + # for jax_arrays, base_jax_arrays, viewed_inds in zip( + # [jaxnodes, jaxedges], + # [self.base.jaxnodes, self.base.jaxedges], + # [self._nodes_in_view, self._edges_in_view], + # ): + # if base_jax_arrays is not None and len(viewed_inds) > 0: + # for key, values in base_jax_arrays.items(): + # mech_inds = self.base._get_state_param_inds(key) + # for mech, inds in mech_inds.items(): + # if mech is None or mech in mechs: + # jax_arrays[key] = values.at[ + # a_intersects_b_at(inds, viewed_inds) + # ] return jaxnodes, jaxedges diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 3f1f7339..a86cf8a0 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -22,7 +22,6 @@ dtype_aware_concat, loc_of_index, merge_cells, - query_states_and_params, ) from jaxley.utils.misc_utils import concat_and_ignore_empty, cumsum_leading_zero from jaxley.utils.solver_utils import ( @@ -267,16 +266,16 @@ def _step_synapse_state( ) -> Dict: voltages = states["v"] - for i, group in edges.groupby("type_ind"): - synapse = syn_channels[i] - pre_inds = group["pre_global_comp_index"].to_numpy() - post_inds = group["post_global_comp_index"].to_numpy() + for synapse in syn_channels: + inds = self._mech_inds[synapse._name] + pre_inds = edges.loc[inds, "pre_global_comp_index"].to_numpy() + post_inds = edges.loc[inds, "post_global_comp_index"].to_numpy() - synapse_params = query_states_and_params(params, synapse.params) - synapse_states = query_states_and_params(states, synapse.states) + synapse_params = self._filter_states_params(params, synapse) + synapse_states = self._filter_states_params(states, synapse) # State updates. - states_updated = synapse.update_states( + synapse_states_updated = synapse.update_states( synapse_states, delta_t, voltages[pre_inds], @@ -286,8 +285,9 @@ def _step_synapse_state( # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in states_updated.items(): - states[key] = states[key].at[:].set(val) + for key, val in synapse_states_updated.items(): + synapse_inds = self._mech_lookup[key][synapse._name]["local_inds"] + states[key] = states[key].at[synapse_inds].set(val) return states @@ -303,7 +303,7 @@ def _synapse_currents( # Compute current through synapses. zeros = jnp.zeros_like(voltages) - syn_voltage_terms, syn_constant_terms = zeros, zeros + syn_voltage_terms, syn_const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 @@ -315,8 +315,8 @@ def _synapse_currents( pre_inds = group["pre_global_comp_index"].to_numpy() post_inds = group["post_global_comp_index"].to_numpy() - synapse_params = query_states_and_params(params, synapse.params) - synapse_states = query_states_and_params(states, synapse.states) + synapse_params = self._filter_states_params(params, synapse) + synapse_states = self._filter_states_params(states, synapse) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) @@ -346,7 +346,7 @@ def _synapse_currents( gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) - syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1]) + syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1]) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. synapse_current_states[f"i_{synapse._name}"] = ( @@ -359,7 +359,7 @@ def _synapse_currents( # recorded and used by `Channel.update_states()`. for name in [s._name for s in self.synapses]: states[f"i_{name}"] = synapse_current_states[f"i_{name}"] - return states, (syn_voltage_terms, syn_constant_terms) + return states, (syn_voltage_terms, syn_const_terms) def arrange_in_layers( self, diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index a1e35c81..27e337ab 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -686,21 +686,6 @@ def group_and_sum( return group_sums -def query_states_and_params(d, keys, idcs=None): - """Get dict with subset of keys and values from d. - - This is used to restrict a dict where every item contains __all__ states to only - the ones that are relevant for the channel. E.g. - - ```states = {'eCa': Array([ 0., 0., nan]}``` - - will be - ```states = {'eCa': Array([ 0., 0.]}``` - - Only loops over necessary keys, as opposed to looping over `d.items()`.""" - return dict(zip(keys, (v if idcs is None else v[idcs] for v in map(d.get, keys)))) - - def compute_axial_conductances( comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] ) -> jnp.ndarray: From f139f1592392c0db5b1c81e1a8c43a833cbbb838 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 13 Dec 2024 17:15:22 +0100 Subject: [PATCH 15/26] wip: more refactoring in light of recent discussion about new channel api --- jaxley/modules/base.py | 265 ++++++++++++-------------------------- jaxley/modules/network.py | 34 ++--- 2 files changed, 90 insertions(+), 209 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 17cf4223..3a75990b 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -143,9 +143,6 @@ def __init__(self): # List of all types of `jx.Synapse`s. self.synapses: List = [] - self.synapse_param_names = [] - self.synapse_state_names = [] - self.synapse_names = [] self.synapse_current_names: List[str] = [] # List of types of all `jx.Channel`s. @@ -187,7 +184,9 @@ def __str__(self): def __dir__(self): base_dir = object.__dir__(self) - return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys())) + synapses = [s._name for s in self.synapses] + groups = [] if len(self.groups) == 0 else list(self.groups.keys()) + return sorted(base_dir + synapses + groups) def __getattr__(self, key): # Ensure that hidden methods such as `__deepcopy__` still work. @@ -213,14 +212,16 @@ def __getattr__(self, key): return view # intercepts calls to synapse types - if key in self.base.synapse_names: + base_syn_names = [s._name for s in self.base.synapses] + syn_names = [s._name for s in self.synapses] + if key in base_syn_names: syn_inds = self.edges[self.edges["type"] == key][ "global_edge_index" ].to_numpy() orig_scope = self._scope view = ( self.scope("global").edge(syn_inds).scope(orig_scope) - if key in self.synapse_names + if key in syn_names else self.select(None) ) view._set_controlled_by_param(key) # overwrites param set by edge @@ -721,7 +722,6 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel._name self.base.nodes.loc[self.nodes[name].isna(), name] = False - @only_allow_module def to_jax(self): # TODO FROM #447: Make this work for View? """Move `.nodes` to `.jaxnodes`. @@ -737,23 +737,18 @@ def to_jax(self): # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means # a seperate lookup table is needed to figure out which parameters and states # belong to which mechanism and at which global indices the mechanism lives. - self._update_mech_lookup_table() - syn_param_states = sum( - [list(s.params) + list(s.states) for s in self.synapses], [] - ) + self._prepare_for_jax() jaxnodes, jaxedges = {}, {} - for state_param in self._mech_lookup: - mech_inds = self._get_state_param_inds(state_param) - data = self.edges if state_param in syn_param_states else self.nodes - jax_arrays = jaxedges if state_param in syn_param_states else jaxnodes + for key, inds in self._inds_of_state_param.items(): + data = self.nodes if key in self.nodes.columns else self.edges + jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges - values = data.loc[mech_inds, state_param].to_numpy() - jax_arrays.update({state_param: values}) + inds = self._inds_of_state_param[key] + values = data.loc[inds, key].to_numpy() + jax_arrays.update({key: values}) - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} self.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} @@ -1160,8 +1155,7 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # Loop only over the keys in `pstate` to avoid unnecessary computation. for parameter in pstate: key = parameter["key"] - mech_inds = self._get_state_param_inds(key) - mech_inds = np.concatenate(list(mech_inds.values())) + mech_inds = self._inds_of_state_param[key] data = self.nodes if key in self.nodes.columns else self.edges data.loc[mech_inds, key] = all_params_states[key] @@ -1216,9 +1210,7 @@ def _get_state_names(self) -> Tuple[List, List]: Returns states seperated by comps and edges.""" channel_states = [name for c in self.channels for name in c.states] - synapse_states = [ - name for s in self.synapses if s is not None for name in s.states - ] + synapse_states = [name for s in self.synapses for name in s.states] membrane_states = ["v", "i"] + self.membrane_current_names return ( channel_states + membrane_states, @@ -1243,85 +1235,43 @@ def _iter_states_params( # assert that either params or states is True assert params or states, "Either params or states must be True." + global_states = ["v"] morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states_params = morph_params if params else [] - global_states_params += ["v"] if states else [] - for key in global_states_params: - yield key, self.jaxnodes[key] - # for key in self.synapse_current_names: - # yield key, self.jaxedges[key] + all_mechs = self.channels + self.synapses + all_states = sum([list(m.states) for m in all_mechs], []) + global_states + all_params = sum([list(m.params) for m in all_mechs], []) + morph_params + all_states_params = all_states if states else [] + all_states_params += all_params if params else [] # Join node and edge states into a single state dictionary. - for jax_arrays, mechs in zip( - [self.jaxnodes, self.jaxedges], - [self.channels, self.synapses], - ): - for mech in mechs: - mech_params_states = mech.params if params else {} - mech_params_states.update(mech.states if states else {}) - for key in mech_params_states: - yield key, jax_arrays[key] + for key in all_states_params: + jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges + yield key, jax_arrays[key], self._inds_of_state_param[key] + + def _prepare_for_jax(self): + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] - def _update_mech_lookup_table(self): - state_param_lookup = {} - mech2inds = {} + def inds_of_key(key): + data = self.nodes if key in self.nodes.columns else self.edges + return data.index[~data[key].isna()].to_numpy() + + self._inds_of_state_param = { + k: inds_of_key(k) for k in global_states + global_params + } for mech in self.channels + self.synapses: is_channel = isinstance(mech, Channel) data = self.nodes if is_channel else self.edges cond = data[mech._name] if is_channel else data["type"] == mech._name + inds = data.index[cond].to_numpy() + mech.indices = jnp.asarray(inds) - chan_inds = {mech._name: data.index[cond].to_numpy()} - mech2inds.update(chan_inds) - for item in list(mech.params) + list(mech.states): - inds = chan_inds[mech._name] - mech_info = {mech._name: {"global_index": inds}} - if item not in state_param_lookup: - state_param_lookup[item] = mech_info - else: - state_param_lookup[item].update(mech_info) - - for state_param, info in state_param_lookup.items(): - global_inds = [v["global_index"] for v in info.values()] - comb_inds = np.unique(np.concatenate(global_inds)) - local_inds = { - k: jnp.searchsorted(comb_inds, v["global_index"]) - for k, v in info.items() - } - new_info = {} - for v1, (k, v2) in zip(global_inds, local_inds.items()): - new_info.update({k: {"global_index": v1, "local_index": v2}}) - state_param_lookup[state_param].update(new_info) - - self._mech_lookup = state_param_lookup - self._mech_inds = mech2inds - - def _get_state_param_inds(self, key: str) -> Tuple[str, jnp.ndarray]: - inds = None - if key in self._mech_lookup: - mechs = self._mech_lookup[key] - inds = np.unique(np.concatenate([self._mech_inds[m] for m in mechs])) - elif key in self.nodes.columns: - inds = {key: self._nodes_in_view} - elif key in self.edges.columns: - inds = {key: self._edges_in_view} - else: - raise KeyError(f"Key '{key}' not found in nodes or edges") - - return inds - - def _filter_states_params( - self, states_params: Dict[str, jnp.ndarray], mech: Union[Channel, Synapse] - ) -> Dict[str, jnp.ndarray]: - pkeys = list(mech.params) - mech_states_params = pkeys if pkeys[0] in states_params else list(mech.states) - - filtered_states_params = {} - for key in mech_states_params: - mech_inds = self._mech_lookup[key][mech._name]["local_index"] - filtered_states_params[key] = states_params[key][mech_inds] - return filtered_states_params + for key in list(mech.params) + list(mech.states): + is_global = mech._name not in key + param_state_inds = inds_of_key(key) if is_global else inds + self._inds_of_state_param[key] = jnp.asarray(param_state_inds) @only_allow_module def _get_all_states_params( @@ -1334,7 +1284,7 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_arrays in self._iter_states_params(params, states): + for key, jax_arrays, _ in self._iter_states_params(params, states): states_params[key] = jax_arrays # Override with those parameters set by `.make_trainable()`. @@ -1346,7 +1296,7 @@ def _get_all_states_params( # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - mech_inds = self._get_state_param_inds(key) + mech_inds = self._inds_of_state_param[key] inds = jnp.searchsorted(mech_inds, inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) @@ -1453,7 +1403,7 @@ def init_states(self, delta_t: float = 0.025): # Update states of the channels. self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, jax_arrays in self._iter_states_params(states=True): + for key, jax_arrays, _ in self._iter_states_params(states=True): states[key] = jax_arrays # We do not use any `pstate` for initializing. In principle, we could change @@ -1464,12 +1414,11 @@ def init_states(self, delta_t: float = 0.025): voltages = self.nodes["v"].to_numpy() for channel in self.channels: - channel_states = self._filter_states_params(states, channel) - channel_params = self._filter_states_params(params, channel) + params = self._filter_global_params_states(params, channel) + states = self._filter_global_params_states(states, channel) - channel_inds = self._mech_inds[channel._name] init_state = channel.init_state( - channel_states, voltages[channel_inds], channel_params, delta_t + states, voltages[channel.indices], params, delta_t ) # `init_state` might not return all channel states. Only the ones that are @@ -1478,7 +1427,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel_inds, key] = val + self.nodes.loc[channel.indices, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs. @@ -1992,30 +1941,30 @@ def _step_channels_state( ) -> Dict[str, jnp.ndarray]: """One integration step of the channels.""" voltages = states["v"] - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - morph_params = {pkey: params[pkey] for pkey in morph_params} - current_states = {name: states[name] for name in self.membrane_current_names} for channel in channels: - channel_params = self._filter_states_params(params, channel) - channel_states = self._filter_states_params(states, channel) - - channel_params.update(morph_params) - channel_states.update(current_states) - # States updates. - channel_inds = self._mech_inds[channel._name] channel_states_updated = channel.update_states( - channel_states, delta_t, voltages[channel_inds], channel_params + states, delta_t, voltages[channel.indices], params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in channel_states_updated.items(): - channel_inds = self._mech_lookup[key][channel._name]["local_index"] - states[key] = states[key].at[channel_inds].set(val) + states[key] = states[key].at[:].set(val) return states + def _filter_global_params_states(self, dct, mech): + mech_state_params = list(mech.params) + list(mech.states) + is_global = lambda key: f"{mech._name}_" not in key and key in dct + global_params_states = [key for key in mech_state_params if is_global(key)] + for key in global_params_states: + param_inds = self._inds_of_state_param[key] + param_where_channel = jnp.searchsorted(param_inds, mech.indices) + dct[key] = dct[key][param_where_channel] + + return dct + def _channel_currents( self, states: Dict[str, jnp.ndarray], @@ -2040,17 +1989,15 @@ def _channel_currents( diff = 1e-3 current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - channel_params = self._filter_states_params(params, channel) - channel_states = self._filter_states_params(states, channel) - - channel_params.update(morph_params) - - channel_inds = self._mech_inds[channel._name] + channel_inds = channel.indices v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) + params = self._filter_global_params_states(params, channel) + states = self._filter_global_params_states(states, channel) + membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( - channel_states, v_and_perturbed, channel_params + states, v_and_perturbed, params ) # Split into voltage and constant terms. @@ -2472,10 +2419,12 @@ def __init__( self.ncomp = pointer.ncomp self.nodes = pointer.nodes.loc[self._nodes_in_view] - ptr_edges = pointer.edges - self.edges = ( - ptr_edges if ptr_edges.empty else ptr_edges.loc[self._edges_in_view] - ) + self.edges = pointer.edges + if not self.edges.empty: + self.edges = pointer.edges.loc[self._edges_in_view] + + # re-enumerate type_inds + self.edges["type_ind"] = self.edges["type"].astype("category").cat.codes self.xyzr = self._xyzr_in_view() self.ncomp = 1 if len(self.nodes) == 1 else pointer.ncomp @@ -2487,8 +2436,8 @@ def __init__( self.ncomp_per_branch = self.base.ncomp_per_branch[self._branches_in_view] self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch) - self.synapse_names = np.unique(self.edges["type"]).tolist() - self._set_synapses_in_view(pointer) + self.synapses = self._synapses_in_view(pointer) + self.channels = self._channels_in_view(pointer) ptr_recs = pointer.recordings self.recordings = ( @@ -2497,7 +2446,6 @@ def __init__( else ptr_recs.loc[ptr_recs["rec_index"].isin(self._comps_in_view)] ) - self.channels = self._channels_in_view(pointer) self.membrane_current_names = [c.current_name for c in self.channels] self.synapse_current_names = pointer.synapse_current_names self._set_trainables_in_view() # run after synapses and channels @@ -2514,9 +2462,8 @@ def __init__( k: np.intersect1d(v, self._nodes_in_view) for k, v in pointer.groups.items() } - self.jaxnodes, self.jaxedges = self._jax_arrays_in_view( - pointer - ) # run after trainables + if pointer.jaxnodes: + self.to_jax() self._current_view = "view" # if not instantiated via `comp`, `cell` etc. self._update_local_indices() @@ -2565,43 +2512,6 @@ def _set_inds_in_view( self._nodes_in_view = nodes self._edges_in_view = edges - def _jax_arrays_in_view(self, pointer: Union[Module, View]): - """Update jaxnodes/jaxedges to show only those currently in view.""" - a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - - jaxnodes = {} if self.base.jaxnodes else None - jaxedges = {} if self.base.jaxedges else None - - # # None check is needed for View -> see `View._jax_arrays_in_view` - # chan_mechs = [m._name for m in self.channels] - # syn_mechs = [m._name for m in self.synapses if m is not None] - # mechs = chan_mechs + syn_mechs - # if self.base._mech_lookup: - # self._mech_lookup = {} - # self._mech_inds = {} - # for k, v in self.base._mech_inds.items(): - # inds = self._nodes_in_view if k in chan_mechs else self._edges_in_view - # self._mech_inds[k] = v[a_intersects_b_at(v, inds)] - # for k, v in self.base._mech_lookup.items(): - # self._mech_lookup[k] = {k:v_i for k,v_i in v.items() if k in mechs} - - # for jax_arrays, base_jax_arrays, viewed_inds in zip( - # [jaxnodes, jaxedges], - # [self.base.jaxnodes, self.base.jaxedges], - # [self._nodes_in_view, self._edges_in_view], - # ): - # if base_jax_arrays is not None and len(viewed_inds) > 0: - # for key, values in base_jax_arrays.items(): - # mech_inds = self.base._get_state_param_inds(key) - # for mech, inds in mech_inds.items(): - # if mech is None or mech in mechs: - # jax_arrays[key] = values.at[ - # a_intersects_b_at(inds, viewed_inds) - # ] - - return jaxnodes, jaxedges - def _set_externals_in_view(self): """Update external inputs to show only those currently in view.""" self.externals = {} @@ -2691,25 +2601,12 @@ def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: names = [name._name for name in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index - return [c for c in pointer.channels if c._name in channel_in_view] + return [deepcopy(c) for c in pointer.channels if c._name in channel_in_view] - def _set_synapses_in_view(self, pointer: Union[Module, View]): + def _synapses_in_view(self, pointer: Union[Module, View]): """Set synapses to show only those in view.""" - viewed_synapses = [] - viewed_params = [] - viewed_states = [] - if pointer.synapses is not None: - for syn in pointer.synapses: - if syn is not None: # needed for recurive viewing - in_view = syn._name in self.synapse_names - viewed_synapses += ( - [syn] if in_view else [None] - ) # padded with None to keep indices consistent - viewed_params += list(syn.params.keys()) if in_view else [] - viewed_states += list(syn.states.keys()) if in_view else [] - self.synapses = viewed_synapses - self.synapse_param_names = viewed_params - self.synapse_state_names = viewed_states + names = self.edges["type"].unique() + return [deepcopy(syn) for syn in pointer.synapses if syn._name in names] def _nbranches_per_cell_in_view(self) -> np.ndarray: cell_nodes = self.nodes.groupby("global_cell_index") diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a86cf8a0..05f6fabe 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -250,7 +250,6 @@ def _step_synapse( ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]: """Perform one step of the synapses and obtain their currents.""" states = self._step_synapse_state(states, syn_channels, params, delta_t, edges) - # import jax; jax.debug.print("1 {}", states["TestSynapse_c"]) states, current_terms = self._synapse_currents( states, syn_channels, params, delta_t, edges ) @@ -267,27 +266,22 @@ def _step_synapse_state( voltages = states["v"] for synapse in syn_channels: - inds = self._mech_inds[synapse._name] - pre_inds = edges.loc[inds, "pre_global_comp_index"].to_numpy() - post_inds = edges.loc[inds, "post_global_comp_index"].to_numpy() - - synapse_params = self._filter_states_params(params, synapse) - synapse_states = self._filter_states_params(states, synapse) + pre_inds = edges.loc[synapse.indices, "pre_global_comp_index"].to_numpy() + post_inds = edges.loc[synapse.indices, "post_global_comp_index"].to_numpy() # State updates. synapse_states_updated = synapse.update_states( - synapse_states, + states, delta_t, voltages[pre_inds], voltages[post_inds], - synapse_params, + params, ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in synapse_states_updated.items(): - synapse_inds = self._mech_lookup[key][synapse._name]["local_inds"] - states[key] = states[key].at[synapse_inds].set(val) + states[key] = states[key].at[:].set(val) return states @@ -315,9 +309,6 @@ def _synapse_currents( pre_inds = group["pre_global_comp_index"].to_numpy() post_inds = group["post_global_comp_index"].to_numpy() - synapse_params = self._filter_states_params(params, synapse) - synapse_states = self._filter_states_params(states, synapse) - v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) post_v_and_perturbed = jnp.array([v_post, v_post + diff]) @@ -325,10 +316,10 @@ def _synapse_currents( synapse_currents = vmap( synapse.compute_current, in_axes=(None, 0, 0, None) )( - synapse_states, + states, pre_v_and_perturbed, post_v_and_perturbed, - synapse_params, + params, ) synapse_currents_dist = compute_current_density( synapse_currents, @@ -483,25 +474,18 @@ def vis( return ax def _infer_synapse_type_ind(self, synapse_name): - syn_names = self.base.synapse_names + syn_names = [s._name for s in self.base.synapses] is_new_type = False if synapse_name in syn_names else True type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) return type_ind, is_new_type - def _update_synapse_state_names(self, synapse_type): - # (Potentially) update variables that track meta information about synapses. - self.base.synapse_names.append(synapse_type._name) - self.base.synapse_param_names += list(synapse_type.params.keys()) - self.base.synapse_state_names += list(synapse_type.states.keys()) - self.base.synapses.append(synapse_type) - def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. synapse_name = synapse_type._name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name) if is_new: # synapse is not known - self._update_synapse_state_names(synapse_type) + self.base.synapses.append(synapse_type) self.base.synapse_current_names.append(synapse_current_name) index = len(self.base.edges) From 44c4a80ae555f395902c68034b827def0ad171f7 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 13 Dec 2024 17:42:50 +0100 Subject: [PATCH 16/26] fix: fix jitting issues of to_jax! --- jaxley/modules/base.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 3a75990b..e6cb318e 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -723,7 +723,6 @@ def _gather_channels_from_constituents(self, constituents: List): self.base.nodes.loc[self.nodes[name].isna(), name] = False def to_jax(self): - # TODO FROM #447: Make this work for View? """Move `.nodes` to `.jaxnodes`. Before the actual simulation is run (via `jx.integrate`), all parameters of @@ -741,12 +740,15 @@ def to_jax(self): jaxnodes, jaxedges = {}, {} + nodes = self.nodes.to_dict(orient="list") + edges = self.edges.to_dict(orient="list") + for key, inds in self._inds_of_state_param.items(): - data = self.nodes if key in self.nodes.columns else self.edges + data = nodes if key in self.nodes.columns else edges jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges inds = self._inds_of_state_param[key] - values = data.loc[inds, key].to_numpy() + values = jnp.asarray(data[key])[inds] jax_arrays.update({key: values}) self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} @@ -1115,7 +1117,6 @@ def make_trainable( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}" ) - @only_allow_module def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): """Write the trainables into `.nodes` and `.edges`. @@ -1124,10 +1125,6 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): Args: trainable_params: The trainable parameters returned by `get_parameters()`. """ - # We do not support views. Why? `jaxedges` does not have any NaN - # elements, whereas edges does. Because of this, we already need special - # treatment to make this function work, and it would be an even bigger hassle - # if we wanted to support this. assert self.__class__.__name__ in [ "Compartment", "Branch", @@ -1156,7 +1153,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): for parameter in pstate: key = parameter["key"] mech_inds = self._inds_of_state_param[key] - data = self.nodes if key in self.nodes.columns else self.edges + data = ( + self.base.nodes if key in self.base.nodes.columns else self.base.edges + ) data.loc[mech_inds, key] = all_params_states[key] def distance(self, endpoint: "View") -> float: @@ -1273,7 +1272,6 @@ def inds_of_key(key): param_state_inds = inds_of_key(key) if is_global else inds self._inds_of_state_param[key] = jnp.asarray(param_state_inds) - @only_allow_module def _get_all_states_params( self, pstate: List[Dict], @@ -1319,7 +1317,6 @@ def _get_all_states_params( ) return states_params - @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: @@ -1360,7 +1357,6 @@ def get_all_parameters( ) return params - @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: @@ -1390,7 +1386,6 @@ def _initialize(self): self._init_morph() return self - @only_allow_module def init_states(self, delta_t: float = 0.025): # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. @@ -1427,7 +1422,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel.indices, key] = val + self.base.nodes.loc[channel.indices, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs. From cd4f6648bcb1b498902719fd9a77a89d010fdaac Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 16 Dec 2024 17:47:56 +0100 Subject: [PATCH 17/26] fix: all tests finally passing --- jaxley/modules/base.py | 149 ++++++++++++++++++++++++----------------- 1 file changed, 87 insertions(+), 62 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e6cb318e..dbd1f7ce 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn +import jax import jax.numpy as jnp import numpy as np import pandas as pd @@ -743,11 +744,10 @@ def to_jax(self): nodes = self.nodes.to_dict(orient="list") edges = self.edges.to_dict(orient="list") - for key, inds in self._inds_of_state_param.items(): + for key, inds in self._iter_states_params(states=True, params=True): data = nodes if key in self.nodes.columns else edges jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges - inds = self._inds_of_state_param[key] values = jnp.asarray(data[key])[inds] jax_arrays.update({key: values}) @@ -1150,13 +1150,13 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): ) # Loop only over the keys in `pstate` to avoid unnecessary computation. - for parameter in pstate: - key = parameter["key"] - mech_inds = self._inds_of_state_param[key] + for p in pstate: + key, inds = p["key"], p["indices"] + inds = np.array(inds.reshape(-1)) data = ( self.base.nodes if key in self.base.nodes.columns else self.base.edges ) - data.loc[mech_inds, key] = all_params_states[key] + data.loc[inds, key] = all_params_states[key][inds] def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. @@ -1228,49 +1228,75 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: return self.trainable_params def _iter_states_params( - self, params=False, states=False + self, params=False, states=False, currents=False ) -> Tuple[str, jnp.ndarray]: - # TODO FROM #447: MAKE THIS WORK FOR VIEW? - # assert that either params or states is True - assert params or states, "Either params or states must be True." + assert params or states or currents, "Select either params / states / currents." + global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + + current_names = self.membrane_current_names + self.synapse_current_names + channel_currents = [c.current_name for c in self.channels] all_mechs = self.channels + self.synapses all_states = sum([list(m.states) for m in all_mechs], []) + global_states - all_params = sum([list(m.params) for m in all_mechs], []) + morph_params - all_states_params = all_states if states else [] - all_states_params += all_params if params else [] + all_params = sum([list(m.params) for m in all_mechs], []) + global_params - # Join node and edge states into a single state dictionary. - for key in all_states_params: - jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges - yield key, jax_arrays[key], self._inds_of_state_param[key] + if params: + for key in all_states: + yield key, self._inds_of_state_param(key) + + if states: + for key in all_params: + yield key, self._inds_of_state_param(key) + + if currents: + for key in current_names + channel_currents: + yield key, self._inds_of_state_param(key) def _prepare_for_jax(self): + # prepare lookup of indices of states, parameters and mechanisms global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] - def inds_of_key(key): - data = self.nodes if key in self.nodes.columns else self.edges - return data.index[~data[key].isna()].to_numpy() + current_names = self.membrane_current_names + self.synapse_current_names + global_states_params = global_states + global_params + current_names + node_attrs = self.nodes.columns.to_list() + current_names + channel_names - self._inds_of_state_param = { - k: inds_of_key(k) for k in global_states + global_params - } + channel_names = [c._name for c in self.channels] + syn_names = [s._name for s in self.synapses] + + def inds_of_key(key: str) -> np.ndarray: + """Return the indices for params, states, mechanisms and currents.""" + data = self.nodes if key in node_attrs else pd.DataFrame() + data = self.edges if key in self.edges.columns or key in syn_names else data + + if key in channel_names + syn_names: + where = data["type"] == key if key in syn_names else data[key] + elif key in data.columns: + where = ~data[key].isna() + elif key in global_states_params: + where = pd.Index([True] * len(data)) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") + return data.index[where].to_numpy() + + # expose the lookup function to the class with precomputed attrs in scope + self._inds_of_state_param = inds_of_key + # add index attrs to mechansisms (i.e. where was it inserted) and also keep track + # of states / parameters that are also shared by other mechanisms. for mech in self.channels + self.synapses: - is_channel = isinstance(mech, Channel) - data = self.nodes if is_channel else self.edges - cond = data[mech._name] if is_channel else data["type"] == mech._name - inds = data.index[cond].to_numpy() - mech.indices = jnp.asarray(inds) + mech.indices = self._inds_of_state_param(mech._name) + mech._jax_inds = {} + currents = {mech.current_name: None} if isinstance(mech, Channel) else {} - for key in list(mech.params) + list(mech.states): - is_global = mech._name not in key - param_state_inds = inds_of_key(key) if is_global else inds - self._inds_of_state_param[key] = jnp.asarray(param_state_inds) + for param_state in {**mech.params, **mech.states, **currents}: + is_global = not param_state.startswith(f"{mech._name}_") + if is_global: + global_inds = self._inds_of_state_param(param_state) + local_inds = np.where(np.isin(global_inds, mech.indices))[0] + mech._jax_inds[param_state] = local_inds def _get_all_states_params( self, @@ -1282,20 +1308,22 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_arrays, _ in self._iter_states_params(params, states): - states_params[key] = jax_arrays + pkeys = {} + for i, p in enumerate(pstate): + pkeys[p["key"]] = pkeys[p["key"]] + [i] if p["key"] in pkeys else [i] - # Override with those parameters set by `.make_trainable()`. - for p in pstate: - key, inds, set_param = p["key"], p["indices"], p["val"] - - if key in states_params: + for key, param_state_inds in self._iter_states_params(params, states): + jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges + states_params[key] = jax_arrays[key] + # Override with those parameters set by `.make_trainable()`. + for i in pkeys.get(key, []): + p = pstate[i] + key, inds, set_param = p["key"], p["indices"], p["val"] # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - mech_inds = self._inds_of_state_param[key] - inds = jnp.searchsorted(mech_inds, inds) + inds = jnp.searchsorted(param_state_inds, inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: @@ -1398,8 +1426,9 @@ def init_states(self, delta_t: float = 0.025): # Update states of the channels. self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, jax_arrays, _ in self._iter_states_params(states=True): - states[key] = jax_arrays + for key, _ in self._iter_states_params(states=True): + jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges + states[key] = jax_arrays[key] # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. @@ -1409,8 +1438,8 @@ def init_states(self, delta_t: float = 0.025): voltages = self.nodes["v"].to_numpy() for channel in self.channels: - params = self._filter_global_params_states(params, channel) - states = self._filter_global_params_states(states, channel) + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) init_state = channel.init_state( states, voltages[channel.indices], params, delta_t @@ -1757,7 +1786,7 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.params) + list(channel.states) + channel_cols = list({**channel.params, **channel.states}.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1769,6 +1798,12 @@ def delete_channel(self, channel: Channel): else: raise ValueError(f"Channel {name} not found in the module.") + def _filter_params_states(self, pytree, filter_dct): + for key, inds in filter_dct.items(): + if key in pytree: + pytree[key] = pytree[key][inds] + return pytree + @only_allow_module def step( self, @@ -1801,7 +1836,6 @@ def step( Returns: The updated state of the module. """ - # Extract the voltages voltages = u["v"] @@ -1939,6 +1973,8 @@ def _step_channels_state( for channel in channels: # States updates. + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) channel_states_updated = channel.update_states( states, delta_t, voltages[channel.indices], params ) @@ -1949,17 +1985,6 @@ def _step_channels_state( return states - def _filter_global_params_states(self, dct, mech): - mech_state_params = list(mech.params) + list(mech.states) - is_global = lambda key: f"{mech._name}_" not in key and key in dct - global_params_states = [key for key in mech_state_params if is_global(key)] - for key in global_params_states: - param_inds = self._inds_of_state_param[key] - param_where_channel = jnp.searchsorted(param_inds, mech.indices) - dct[key] = dct[key][param_where_channel] - - return dct - def _channel_currents( self, states: Dict[str, jnp.ndarray], @@ -1973,7 +1998,7 @@ def _channel_currents( This is also updates `state` because the `state` also contains the current. """ voltages = states["v"] - morph_params = ["radius", "length", "axial_resistivity"] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] morph_params = {pkey: params[pkey] for pkey in morph_params} # Compute current through channels. @@ -1988,8 +2013,8 @@ def _channel_currents( v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) - params = self._filter_global_params_states(params, channel) - states = self._filter_global_params_states(states, channel) + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( states, v_and_perturbed, params From 8bbc6ba17b4ff26e7624a43adbc3ab3f255099d0 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 16 Dec 2024 17:49:40 +0100 Subject: [PATCH 18/26] fix: ammend last commit --- jaxley/modules/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index dbd1f7ce..f19320a4 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1261,11 +1261,11 @@ def _prepare_for_jax(self): current_names = self.membrane_current_names + self.synapse_current_names global_states_params = global_states + global_params + current_names - node_attrs = self.nodes.columns.to_list() + current_names + channel_names channel_names = [c._name for c in self.channels] syn_names = [s._name for s in self.synapses] + node_attrs = self.nodes.columns.to_list() + current_names + channel_names def inds_of_key(key: str) -> np.ndarray: """Return the indices for params, states, mechanisms and currents.""" data = self.nodes if key in node_attrs else pd.DataFrame() From 5e31be8961dee25ec21b5760f491692ae5564d01 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 16 Dec 2024 18:07:50 +0100 Subject: [PATCH 19/26] fix: small fixes and comments added --- jaxley/modules/base.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f19320a4..a7d0e1bd 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1229,29 +1229,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: def _iter_states_params( self, params=False, states=False, currents=False - ) -> Tuple[str, jnp.ndarray]: + ) -> Tuple[str, np.ndarray]: # type: ignore # assert that either params or states is True assert params or states or currents, "Select either params / states / currents." - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - current_names = self.membrane_current_names + self.synapse_current_names - channel_currents = [c.current_name for c in self.channels] - all_mechs = self.channels + self.synapses - all_states = sum([list(m.states) for m in all_mechs], []) + global_states - all_params = sum([list(m.params) for m in all_mechs], []) + global_params - + if params: - for key in all_states: + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + all_params = sum([list(m.params) for m in all_mechs], []) + global_params + for key in all_params: yield key, self._inds_of_state_param(key) if states: - for key in all_params: + global_states = ["v"] + all_states = sum([list(m.states) for m in all_mechs], []) + global_states + for key in all_states: yield key, self._inds_of_state_param(key) if currents: - for key in current_names + channel_currents: + current_names = self.membrane_current_names + self.synapse_current_names + for key in current_names: yield key, self._inds_of_state_param(key) def _prepare_for_jax(self): @@ -1289,9 +1286,9 @@ def inds_of_key(key: str) -> np.ndarray: for mech in self.channels + self.synapses: mech.indices = self._inds_of_state_param(mech._name) mech._jax_inds = {} - currents = {mech.current_name: None} if isinstance(mech, Channel) else {} + current = {mech.current_name: None} if isinstance(mech, Channel) else {} - for param_state in {**mech.params, **mech.states, **currents}: + for param_state in {**mech.params, **mech.states, **current}: is_global = not param_state.startswith(f"{mech._name}_") if is_global: global_inds = self._inds_of_state_param(param_state) From cc23b2b707d13fec573e4846fc032f8110bc4343 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 17 Dec 2024 00:44:56 +0100 Subject: [PATCH 20/26] fix: move some things around --- jaxley/modules/base.py | 147 +++++++++++++++++++------------------- jaxley/modules/network.py | 12 ++-- 2 files changed, 80 insertions(+), 79 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a7d0e1bd..71435643 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -185,7 +185,7 @@ def __str__(self): def __dir__(self): base_dir = object.__dir__(self) - synapses = [s._name for s in self.synapses] + synapses = [s.name for s in self.synapses] groups = [] if len(self.groups) == 0 else list(self.groups.keys()) return sorted(base_dir + synapses + groups) @@ -205,16 +205,16 @@ def __getattr__(self, key): return view # intercepts calls to channels - if key in [c._name for c in self.base.channels]: - channel_names = [c._name for c in self.channels] + if key in [c.name for c in self.base.channels]: + channel_names = [c.name for c in self.channels] inds = self.nodes.index[self.nodes[key]].to_numpy() view = self.select(inds) if key in channel_names else self.select(None) view._set_controlled_by_param(key) return view # intercepts calls to synapse types - base_syn_names = [s._name for s in self.base.synapses] - syn_names = [s._name for s in self.synapses] + base_syn_names = [s.name for s in self.base.synapses] + syn_names = [s.name for s in self.synapses] if key in base_syn_names: syn_inds = self.edges[self.edges["type"] == key][ "global_edge_index" @@ -714,15 +714,60 @@ def _gather_channels_from_constituents(self, constituents: List): """ for module in constituents: for channel in module.channels: - if channel._name not in [c._name for c in self.channels]: + if channel.name not in [c.name for c in self.channels]: self.base.channels.append(channel) if channel.current_name not in self.membrane_current_names: self.base.membrane_current_names.append(channel.current_name) # Setting columns of channel names to `False` instead of `NaN`. for channel in self.base.channels: - name = channel._name + name = channel.name self.base.nodes.loc[self.nodes[name].isna(), name] = False + def _prepare_for_jax(self): + # prepare lookup of indices of states, parameters and mechanisms + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] + + current_names = self.membrane_current_names + self.synapse_current_names + global_states_params = global_states + global_params + current_names + + channel_names = [c.name for c in self.channels] + syn_names = [s.name for s in self.synapses] + + node_attrs = self.nodes.columns.to_list() + current_names + channel_names + + def inds_of_key(key: str) -> np.ndarray: + """Return the indices for params, states, mechanisms and currents.""" + data = self.nodes if key in node_attrs else pd.DataFrame() + data = self.edges if key in self.edges.columns or key in syn_names else data + + if key in channel_names + syn_names: + where = data["type"] == key if key in syn_names else data[key] + elif key in data.columns: + where = ~data[key].isna() + elif key in global_states_params: + where = pd.Index([True] * len(data)) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") + return data.index[where].to_numpy() + + # expose the lookup function to the class with precomputed attrs in scope + self._inds_of_state_param = inds_of_key + + # add index attrs to mechansisms (i.e. where was it inserted) and also keep track + # of states / parameters that are also shared by other mechanisms. + for mech in self.channels + self.synapses: + mech.indices = self._inds_of_state_param(mech.name) + mech._jax_inds = {} + current = {mech.current_name: None} if isinstance(mech, Channel) else {} + + for param_state in {**mech.params, **mech.states, **current}: + is_global = not param_state.startswith(f"{mech.name}_") + if is_global: + global_inds = self._inds_of_state_param(param_state) + local_inds = np.where(np.isin(global_inds, mech.indices))[0] + mech._jax_inds[param_state] = local_inds + def to_jax(self): """Move `.nodes` to `.jaxnodes`. @@ -784,7 +829,7 @@ def show( scopes = ["local", "global"] inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds - cols += [ch._name for ch in self.channels] if channel_names else [] + cols += [ch.name for ch in self.channels] if channel_names else [] cols += sum([list(ch.params) for ch in self.channels], []) if params else [] cols += sum([list(ch.states) for ch in self.channels], []) if states else [] @@ -914,7 +959,7 @@ def set_ncomp( all_nodes = self.base.nodes start_idx = self.nodes["global_comp_index"].to_numpy()[0] ncomp_per_branch = self.base.ncomp_per_branch - channel_names = [c._name for c in self.base.channels] + channel_names = [c.name for c in self.base.channels] channel_param_names = list(chain(*[c.params for c in self.base.channels])) channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns @@ -1216,33 +1261,22 @@ def _get_state_names(self) -> Tuple[List, List]: synapse_states + self.synapse_current_names, ) - def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: - """Get all trainable parameters. - - The returned parameters should be passed to `jx.integrate(..., params=params). - - Returns: - A list of all trainable parameters in the form of - [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. - """ - return self.trainable_params - def _iter_states_params( self, params=False, states=False, currents=False - ) -> Tuple[str, np.ndarray]: # type: ignore + ) -> Tuple[str, np.ndarray]: # type: ignore # assert that either params or states is True assert params or states or currents, "Select either params / states / currents." all_mechs = self.channels + self.synapses - + if params: global_params = ["radius", "length", "axial_resistivity", "capacitance"] - all_params = sum([list(m.params) for m in all_mechs], []) + global_params + all_params = [p for m in all_mechs for p in m.params] + global_params for key in all_params: yield key, self._inds_of_state_param(key) if states: global_states = ["v"] - all_states = sum([list(m.states) for m in all_mechs], []) + global_states + all_states = [s for m in all_mechs for s in m.states] + global_states for key in all_states: yield key, self._inds_of_state_param(key) @@ -1251,49 +1285,16 @@ def _iter_states_params( for key in current_names: yield key, self._inds_of_state_param(key) - def _prepare_for_jax(self): - # prepare lookup of indices of states, parameters and mechanisms - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - current_names = self.membrane_current_names + self.synapse_current_names - global_states_params = global_states + global_params + current_names - - channel_names = [c._name for c in self.channels] - syn_names = [s._name for s in self.synapses] - - node_attrs = self.nodes.columns.to_list() + current_names + channel_names - def inds_of_key(key: str) -> np.ndarray: - """Return the indices for params, states, mechanisms and currents.""" - data = self.nodes if key in node_attrs else pd.DataFrame() - data = self.edges if key in self.edges.columns or key in syn_names else data - - if key in channel_names + syn_names: - where = data["type"] == key if key in syn_names else data[key] - elif key in data.columns: - where = ~data[key].isna() - elif key in global_states_params: - where = pd.Index([True] * len(data)) - else: - raise ValueError(f"Key '{key}' not found in nodes or edges") - return data.index[where].to_numpy() - - # expose the lookup function to the class with precomputed attrs in scope - self._inds_of_state_param = inds_of_key + def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: + """Get all trainable parameters. - # add index attrs to mechansisms (i.e. where was it inserted) and also keep track - # of states / parameters that are also shared by other mechanisms. - for mech in self.channels + self.synapses: - mech.indices = self._inds_of_state_param(mech._name) - mech._jax_inds = {} - current = {mech.current_name: None} if isinstance(mech, Channel) else {} + The returned parameters should be passed to `jx.integrate(..., params=params). - for param_state in {**mech.params, **mech.states, **current}: - is_global = not param_state.startswith(f"{mech._name}_") - if is_global: - global_inds = self._inds_of_state_param(param_state) - local_inds = np.where(np.isin(global_inds, mech.indices))[0] - mech._jax_inds[param_state] = local_inds + Returns: + A list of all trainable parameters in the form of + [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. + """ + return self.trainable_params def _get_all_states_params( self, @@ -1751,10 +1752,10 @@ def insert(self, channel: Channel): Args: channel: The channel to insert.""" - name = channel._name + name = channel.name # Channel does not yet exist in the `jx.Module` at all. - if name not in [c._name for c in self.base.channels]: + if name not in [c.name for c in self.base.channels]: self.base.channels.append(channel) self.base.nodes[name] = ( False # Previous columns do not have the new channel. @@ -1779,9 +1780,9 @@ def delete_channel(self, channel: Channel): Args: channel: The channel to remove.""" - name = channel._name - channel_names = [c._name for c in self.channels] - all_channel_names = [c._name for c in self.base.channels] + name = channel.name + channel_names = [c.name for c in self.channels] + all_channel_names = [c.name for c in self.base.channels] if name in channel_names: channel_cols = list({**channel.params, **channel.states}.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") @@ -2615,15 +2616,15 @@ def _set_trainables_in_view(self): def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: """Set channels to show only those in view.""" - names = [name._name for name in pointer.channels] + names = [c.name for c in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index - return [deepcopy(c) for c in pointer.channels if c._name in channel_in_view] + return [deepcopy(c) for c in pointer.channels if c.name in channel_in_view] def _synapses_in_view(self, pointer: Union[Module, View]): """Set synapses to show only those in view.""" names = self.edges["type"].unique() - return [deepcopy(syn) for syn in pointer.synapses if syn._name in names] + return [deepcopy(syn) for syn in pointer.synapses if syn.name in names] def _nbranches_per_cell_in_view(self) -> np.ndarray: cell_nodes = self.nodes.groupby("global_cell_index") diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 05f6fabe..4a4e4f9d 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -303,7 +303,7 @@ def _synapse_currents( diff = 1e-3 num_comp = len(voltages) - synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels} + synapse_current_states = {f"i_{s.name}": zeros for s in syn_channels} for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] pre_inds = group["pre_global_comp_index"].to_numpy() @@ -340,15 +340,15 @@ def _synapse_currents( syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1]) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. - synapse_current_states[f"i_{synapse._name}"] = ( - synapse_current_states[f"i_{synapse._name}"] + synapse_current_states[f"i_{synapse.name}"] = ( + synapse_current_states[f"i_{synapse.name}"] .at[post_inds] .add(synapse_currents_dist[0]) ) # Copy the currents into the `state` dictionary such that they can be # recorded and used by `Channel.update_states()`. - for name in [s._name for s in self.synapses]: + for name in [s.name for s in self.synapses]: states[f"i_{name}"] = synapse_current_states[f"i_{name}"] return states, (syn_voltage_terms, syn_const_terms) @@ -474,14 +474,14 @@ def vis( return ax def _infer_synapse_type_ind(self, synapse_name): - syn_names = [s._name for s in self.base.synapses] + syn_names = [s.name for s in self.base.synapses] is_new_type = False if synapse_name in syn_names else True type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) return type_ind, is_new_type def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. - synapse_name = synapse_type._name + synapse_name = synapse_type.name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name) if is_new: # synapse is not known From efee504f9df67b358b927d5797ec731f21ffa2ca Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 17 Dec 2024 00:54:03 +0100 Subject: [PATCH 21/26] doc: add documentation --- jaxley/modules/base.py | 48 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 71435643..cf9392dd 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -724,6 +724,16 @@ def _gather_channels_from_constituents(self, constituents: List): self.base.nodes.loc[self.nodes[name].isna(), name] = False def _prepare_for_jax(self): + """Prepare the module for simulation with JAX. + + This function has to be run inside or before `to_jax`. It's main purpose is to; + 1. Prepare the lookup of indices of states, parameters and mechanisms. + 2. Add index attributes to mechanisms (i.e. where was it inserted) and also keep + track of states / parameters that are also shared by other mechanisms. + + Adds `_inds_of_state_param(key: str)` to the module and also adds `indices` and + `_jax_inds` to the mechanisms. + """ # prepare lookup of indices of states, parameters and mechanisms global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] @@ -1262,9 +1272,18 @@ def _get_state_names(self) -> Tuple[List, List]: ) def _iter_states_params( - self, params=False, states=False, currents=False + self, params: bool = False, states: bool = False, currents: bool = False ) -> Tuple[str, np.ndarray]: # type: ignore - # assert that either params or states is True + """Iterate over all states and parameters. + + Args: + params: Whether to iterate over parameters. + states: Whether to iterate over states. + currents: Whether to iterate over currents. + + Yields: + The key and the indices of the states / parameters. + """ assert params or states or currents, "Select either params / states / currents." all_mechs = self.channels + self.synapses @@ -1299,12 +1318,27 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: def _get_all_states_params( self, pstate: List[Dict], - voltage_solver=None, - delta_t=None, - all_params=None, - params=False, - states=False, + voltage_solver: str = None, + delta_t: float = None, + all_params: Dict[str, jnp.ndarray] = None, + params: bool = False, + states: bool = False, ) -> Dict[str, jnp.ndarray]: + """Get all parameters and/or states of the module. + + Common backbone of both `get_all_parameters()` and `get_all_states()`. + + Args: + pstate: The state of the trainable parameters. + voltage_solver: The voltage solver that is used. + delta_t: The stepsize. + all_params: All parameters of the module. + params: Whether to get the parameters. + states: Whether to get the states. + + Returns: + A dictionary of all parameters and/or states of the module. + """ states_params = {} pkeys = {} for i, p in enumerate(pstate): From b760e90b8d616858f03d49d5cb9ea0191430f72c Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 17 Dec 2024 12:48:18 +0100 Subject: [PATCH 22/26] fix: fix param sharing --- jaxley/modules/base.py | 3 ++- jaxley/utils/cell_utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index cf9392dd..f49182e2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -31,6 +31,7 @@ compute_axial_conductances, compute_current_density, compute_levels, + index_of_a_in_b, interpolate_xyzr, params_to_pstate, v_interp, @@ -1355,7 +1356,7 @@ def _get_all_states_params( # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - inds = jnp.searchsorted(param_state_inds, inds) + inds = index_of_a_in_b(inds, param_state_inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 27e337ab..ccbe5833 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -772,3 +772,32 @@ def dtype_aware_concat(dfs): concat_df[col] = concat_df[col].astype(df[col].dtype) break # first match is sufficient return concat_df + + +def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray: + """Replace values in A with the indices of the corresponding values in B. + + Mainly used to determine the indices of parameters in jaxnodes based on the global + indices of the parameters in the cell. All values in A that are not in B are + replaced with -1. + + Example: + - indices_of_gNa = [5,6,7,8,9] + - indices_to_change = [6,7] + - index_of_a_in_b(indices_to_change, indices_of_gNa) -> [1,2] + + Args: + A: Array of shape (N, M). + B: Array of shape (N, K). + + Returns: + Array of shape of A with the indices of the values of A in B.""" + matches = A[:, :, None] == B + # Get mask for values that exist in B + exists_in_B = matches.any(axis=-1) + # Get indices where matches occur + indices = jnp.where(matches, jnp.arange(len(B))[None, None, :], 0) + # Sum along last axis to get the indices + result = jnp.sum(indices, axis=-1) + # Replace values not in B with -1 + return jnp.where(exists_in_B, result, -1) From 62351b819b228dfb849c2112ae196dd6b56c5abb Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 23 Dec 2024 02:04:36 +0100 Subject: [PATCH 23/26] doc: update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b4fa8ed..dc56b02f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ net.vis() - changelog added to CI (#537, #558, @jnsbck) +- Refactor of channel and synapse stepping internals and how the model is transferred to jax for more efficient and readable code (#487, @jnsbck). + # 0.5.0 ### API changes From 983817c1357fbf30615c7279b058c47a6da01423 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sat, 11 Jan 2025 20:59:48 +0100 Subject: [PATCH 24/26] fix: refactor of shared states and got rid of prepare_for_jax and other unecessary things --- jaxley/modules/base.py | 257 ++++++++++++++++++-------------------- jaxley/modules/network.py | 24 +++- 2 files changed, 141 insertions(+), 140 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f49182e2..b95fcbce 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -724,91 +724,82 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel.name self.base.nodes.loc[self.nodes[name].isna(), name] = False - def _prepare_for_jax(self): - """Prepare the module for simulation with JAX. + def _inds_of_state_param(self, key: str) -> jnp.ndarray: + """lookup the indices for params or states. - This function has to be run inside or before `to_jax`. It's main purpose is to; - 1. Prepare the lookup of indices of states, parameters and mechanisms. - 2. Add index attributes to mechanisms (i.e. where was it inserted) and also keep - track of states / parameters that are also shared by other mechanisms. + Returns indices that have non-NaN values for the given key in `nodes` or `edges`. - Adds `_inds_of_state_param(key: str)` to the module and also adds `indices` and - `_jax_inds` to the mechanisms. - """ - # prepare lookup of indices of states, parameters and mechanisms - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - current_names = self.membrane_current_names + self.synapse_current_names - global_states_params = global_states + global_params + current_names - - channel_names = [c.name for c in self.channels] - syn_names = [s.name for s in self.synapses] - - node_attrs = self.nodes.columns.to_list() + current_names + channel_names - - def inds_of_key(key: str) -> np.ndarray: - """Return the indices for params, states, mechanisms and currents.""" - data = self.nodes if key in node_attrs else pd.DataFrame() - data = self.edges if key in self.edges.columns or key in syn_names else data + Args: + key: The name of the param or state to get the indices for. - if key in channel_names + syn_names: - where = data["type"] == key if key in syn_names else data[key] - elif key in data.columns: - where = ~data[key].isna() - elif key in global_states_params: - where = pd.Index([True] * len(data)) - else: - raise ValueError(f"Key '{key}' not found in nodes or edges") - return data.index[where].to_numpy() - - # expose the lookup function to the class with precomputed attrs in scope - self._inds_of_state_param = inds_of_key - - # add index attrs to mechansisms (i.e. where was it inserted) and also keep track - # of states / parameters that are also shared by other mechanisms. - for mech in self.channels + self.synapses: - mech.indices = self._inds_of_state_param(mech.name) - mech._jax_inds = {} - current = {mech.current_name: None} if isinstance(mech, Channel) else {} - - for param_state in {**mech.params, **mech.states, **current}: - is_global = not param_state.startswith(f"{mech.name}_") - if is_global: - global_inds = self._inds_of_state_param(param_state) - local_inds = np.where(np.isin(global_inds, mech.indices))[0] - mech._jax_inds[param_state] = local_inds + Returns: + The indices of the param or state. + """ + if key in self.nodes.columns: + data = self.nodes[key] + return jnp.asarray(data.index[data.notna()]) + elif key in self.edges.columns: + data = self.edges[key] + return jnp.asarray(data.index[data.notna()]) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") def to_jax(self): - """Move `.nodes` to `.jaxnodes`. + """Move `.nodes` to `.jaxnodes` and `.edges` to `.jaxedges`. Before the actual simulation is run (via `jx.integrate`), all parameters of the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for simulation, these parameters have to be moved to be `jnp.ndarrays` such that they can be processed on GPU/TPU and such that the simulation can be - differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. + differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes` and the `.edges` + to `.jaxedges`. In addition, jaxglobals keeps indices for parameters and states + that are shared by multiple mechanisms. """ # the parameters and states in the jaxnodes are stored on a per-mechanism basis, # i.e. if only compartment #2 has a HH channels, then the jaxnodes will be # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means # a seperate lookup table is needed to figure out which parameters and states # belong to which mechanism and at which global indices the mechanism lives. - self._prepare_for_jax() - - jaxnodes, jaxedges = {}, {} - nodes = self.nodes.to_dict(orient="list") - edges = self.edges.to_dict(orient="list") + jaxnodes, jaxedges = {"states": {}, "params": {}}, {"states": {}, "params": {}} + jaxglobals = {"states": {}, "params": {}} - for key, inds in self._iter_states_params(states=True, params=True): - data = nodes if key in self.nodes.columns else edges - jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] - values = jnp.asarray(data[key])[inds] - jax_arrays.update({key: values}) + for state in global_states: + jaxnodes["states"][state] = jnp.asarray(self.nodes[state]) + for param in global_params: + jaxnodes["params"][param] = jnp.asarray(self.nodes[param]) - self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} - self.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} + for data, mechs in zip( + [self.nodes, self.edges], [self.channels, self.synapses] + ): + is_nodes = "v" in data.columns + jax_arrays = jaxnodes if is_nodes else jaxedges + for mech in mechs: + where_mech = data[mech.name] if is_nodes else data["type"] == mech.name + mech.indices = jnp.asarray(data.index[where_mech].to_list()) + + params = data.loc[where_mech, mech.params.keys()].to_dict(orient="list") + states = data.loc[where_mech, mech.states.keys()].to_dict(orient="list") + + is_local = lambda x: x.startswith(f"{mech.name}_") + for label, params_or_states in zip( + ["params", "states"], [params, states] + ): + for k, v in params_or_states.items(): + v = v if is_local(k) else data[k][data[k].notna()].to_list() + jax_arrays[label][k] = jnp.asarray(v) + + if not is_local(k): + if mech.name not in jaxglobals[label]: + jaxglobals[label][mech.name] = {} + jaxglobals[label][mech.name][k] = mech.indices + + self.jaxnodes = jaxnodes + self.jaxedges = jaxedges + self.jaxglobals = jaxglobals def show( self, @@ -1272,39 +1263,6 @@ def _get_state_names(self) -> Tuple[List, List]: synapse_states + self.synapse_current_names, ) - def _iter_states_params( - self, params: bool = False, states: bool = False, currents: bool = False - ) -> Tuple[str, np.ndarray]: # type: ignore - """Iterate over all states and parameters. - - Args: - params: Whether to iterate over parameters. - states: Whether to iterate over states. - currents: Whether to iterate over currents. - - Yields: - The key and the indices of the states / parameters. - """ - assert params or states or currents, "Select either params / states / currents." - all_mechs = self.channels + self.synapses - - if params: - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - all_params = [p for m in all_mechs for p in m.params] + global_params - for key in all_params: - yield key, self._inds_of_state_param(key) - - if states: - global_states = ["v"] - all_states = [s for m in all_mechs for s in m.states] + global_states - for key in all_states: - yield key, self._inds_of_state_param(key) - - if currents: - current_names = self.membrane_current_names + self.synapse_current_names - for key in current_names: - yield key, self._inds_of_state_param(key) - def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -1341,23 +1299,22 @@ def _get_all_states_params( A dictionary of all parameters and/or states of the module. """ states_params = {} - pkeys = {} - for i, p in enumerate(pstate): - pkeys[p["key"]] = pkeys[p["key"]] + [i] if p["key"] in pkeys else [i] - - for key, param_state_inds in self._iter_states_params(params, states): - jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges - states_params[key] = jax_arrays[key] - # Override with those parameters set by `.make_trainable()`. - for i in pkeys.get(key, []): - p = pstate[i] - key, inds, set_param = p["key"], p["indices"], p["val"] - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` - # for the `.set()` to work. This is done with `[:, None]`. - inds = index_of_a_in_b(inds, param_state_inds) - states_params[key] = states_params[key].at[inds].set(set_param[:, None]) + + jax_states = {**self.jaxnodes["states"], **self.jaxedges["states"]} + jax_params = {**self.jaxnodes["params"], **self.jaxedges["params"]} + for key, data in {**jax_params, **jax_states}.items(): + states_params[key] = data + + # Override with those parameters set by `.make_trainable()`. + for p in pstate: + key, inds, set_param = p["key"], p["indices"], p["val"] + # `inds` is of shape `(num_params, num_comps_per_param)`. + # `set_param` is of shape `(num_params,)` + # We need to unsqueeze `set_param` to make it `(num_params, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + param_state_inds = self._inds_of_state_param(key) + inds = index_of_a_in_b(inds, param_state_inds) + states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: # Compute conductance params and add them to the params dictionary. @@ -1447,6 +1404,32 @@ def _initialize(self): self._init_morph() return self + def _mech_filter_globals(self, dct, active_mech, globals): + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] + + global_params_or_states_or_currents = [] + global_params_or_states_or_currents += global_states if "v" in dct else [] + global_params_or_states_or_currents += global_params if "radius" in dct else [] + + is_channel = isinstance(active_mech, Channel) + i_mech = active_mech.current_name if is_channel else f"{active_mech.name}_i" + global_params_or_states_or_currents += [i_mech] if i_mech in dct else [] + + mech_inds = active_mech.indices + + filtered_dct = dct.copy() + for key in global_params_or_states_or_currents: + if key in dct: + filtered_dct[key] = dct[key][mech_inds] + + if active_mech.name in globals: + for key, inds in globals[active_mech.name].items(): + if key in dct: + filter_inds = index_of_a_in_b(mech_inds.reshape(1, -1), inds) + filtered_dct[key] = dct[key][filter_inds.flatten()] + return filtered_dct + def init_states(self, delta_t: float = 0.025): # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. @@ -1459,9 +1442,8 @@ def init_states(self, delta_t: float = 0.025): # Update states of the channels. self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, _ in self._iter_states_params(states=True): - jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges - states[key] = jax_arrays[key] + for key, data in {**self.jaxnodes["states"], **self.jaxedges["states"]}.items(): + states[key] = data # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. @@ -1471,11 +1453,13 @@ def init_states(self, delta_t: float = 0.025): voltages = self.nodes["v"].to_numpy() for channel in self.channels: - states = self._filter_params_states(states, channel._jax_inds) - params = self._filter_params_states(params, channel._jax_inds) + global_states = self.jaxglobals["states"] + global_params = self.jaxglobals["params"] + channel_states = self._mech_filter_globals(states, channel, global_states) + channel_params = self._mech_filter_globals(params, channel, global_params) init_state = channel.init_state( - states, voltages[channel.indices], params, delta_t + channel_states, voltages[channel.indices], channel_params, delta_t ) # `init_state` might not return all channel states. Only the ones that are @@ -1831,12 +1815,6 @@ def delete_channel(self, channel: Channel): else: raise ValueError(f"Channel {name} not found in the module.") - def _filter_params_states(self, pytree, filter_dct): - for key, inds in filter_dct.items(): - if key in pytree: - pytree[key] = pytree[key][inds] - return pytree - @only_allow_module def step( self, @@ -2006,15 +1984,22 @@ def _step_channels_state( for channel in channels: # States updates. - states = self._filter_params_states(states, channel._jax_inds) - params = self._filter_params_states(params, channel._jax_inds) + global_states = self.jaxglobals["states"] + global_params = self.jaxglobals["params"] + channel_states = self._mech_filter_globals(states, channel, global_states) + channel_params = self._mech_filter_globals(params, channel, global_params) + channel_states_updated = channel.update_states( - states, delta_t, voltages[channel.indices], params + channel_states, delta_t, voltages[channel.indices], channel_params ) + # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in channel_states_updated.items(): - states[key] = states[key].at[:].set(val) + param_state_inds = self._inds_of_state_param(key) + channel_inds = channel.indices.reshape(1, -1) + inds = index_of_a_in_b(channel_inds, param_state_inds).flatten() + states[key] = states[key].at[inds].set(val) return states @@ -2046,11 +2031,13 @@ def _channel_currents( v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) - states = self._filter_params_states(states, channel._jax_inds) - params = self._filter_params_states(params, channel._jax_inds) + global_states = self.jaxglobals["states"] + global_params = self.jaxglobals["params"] + channel_states = self._mech_filter_globals(states, channel, global_states) + channel_params = self._mech_filter_globals(params, channel, global_params) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( - states, v_and_perturbed, params + channel_states, v_and_perturbed, channel_params ) # Split into voltage and constant terms. @@ -2654,7 +2641,7 @@ def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: names = [c.name for c in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index - return [deepcopy(c) for c in pointer.channels if c.name in channel_in_view] + return [c for c in self.base.channels if c.name in channel_in_view] def _synapses_in_view(self, pointer: Union[Module, View]): """Set synapses to show only those in view.""" diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 4a4e4f9d..9292d35b 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -22,6 +22,7 @@ dtype_aware_concat, loc_of_index, merge_cells, + index_of_a_in_b, ) from jaxley.utils.misc_utils import concat_and_ignore_empty, cumsum_leading_zero from jaxley.utils.solver_utils import ( @@ -269,19 +270,27 @@ def _step_synapse_state( pre_inds = edges.loc[synapse.indices, "pre_global_comp_index"].to_numpy() post_inds = edges.loc[synapse.indices, "post_global_comp_index"].to_numpy() + global_states = self.jaxglobals["states"] + global_params = self.jaxglobals["params"] + syn_states = self._mech_filter_globals(states, synapse, global_states) + syn_params = self._mech_filter_globals(params, synapse, global_params) + # State updates. synapse_states_updated = synapse.update_states( - states, + syn_states, delta_t, voltages[pre_inds], voltages[post_inds], - params, + syn_params, ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. for key, val in synapse_states_updated.items(): - states[key] = states[key].at[:].set(val) + param_state_inds = self._inds_of_state_param(key) + synapse_inds = synapse.indices.reshape(1, -1) + inds = index_of_a_in_b(synapse_inds, param_state_inds).flatten() + states[key] = states[key].at[inds].set(val) return states @@ -313,13 +322,18 @@ def _synapse_currents( pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) post_v_and_perturbed = jnp.array([v_post, v_post + diff]) + global_states = self.jaxglobals["states"] + global_params = self.jaxglobals["params"] + syn_states = self._mech_filter_globals(states, synapse, global_states) + syn_params = self._mech_filter_globals(params, synapse, global_params) + synapse_currents = vmap( synapse.compute_current, in_axes=(None, 0, 0, None) )( - states, + syn_states, pre_v_and_perturbed, post_v_and_perturbed, - params, + syn_params, ) synapse_currents_dist = compute_current_density( synapse_currents, From 6f7f3890f2b3c7fc422ac8ffc2fd55085580d9f2 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 13 Jan 2025 15:35:07 +0100 Subject: [PATCH 25/26] fix: major refactor of jaxnodes and fix regression tests. --- jaxley/modules/base.py | 214 +++++++++++++++++++------------------ jaxley/modules/network.py | 45 ++------ jaxley/utils/cell_utils.py | 41 +++++-- 3 files changed, 150 insertions(+), 150 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index b95fcbce..e75d9df9 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -33,6 +33,7 @@ compute_levels, index_of_a_in_b, interpolate_xyzr, + iterate_leaves, params_to_pstate, v_interp, ) @@ -745,61 +746,60 @@ def _inds_of_state_param(self, key: str) -> jnp.ndarray: raise ValueError(f"Key '{key}' not found in nodes or edges") def to_jax(self): - """Move `.nodes` to `.jaxnodes` and `.edges` to `.jaxedges`. + """Move `jx.Module` to `jax`. Before the actual simulation is run (via `jx.integrate`), all parameters of - the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for + the `jx.Module` are stored in `.nodes`/`.edges` (`pd.DataFrame`). However, for simulation, these parameters have to be moved to be `jnp.ndarrays` such that they can be processed on GPU/TPU and such that the simulation can be - differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes` and the `.edges` - to `.jaxedges`. In addition, jaxglobals keeps indices for parameters and states - that are shared by multiple mechanisms. + differentiated. `.to_jax()` copies `.nodes` to `.jax["nodes"]` and `.edges` + to `.jax["edges"]`. In addition, jax["global"] keeps track of parameters and + states that are shared by multiple mechanisms. """ - # the parameters and states in the jaxnodes are stored on a per-mechanism basis, - # i.e. if only compartment #2 has a HH channels, then the jaxnodes will be + # the parameters and states in the jax["nodes"] are stored on a per-mechanism basis, + # i.e. if only compartment #2 has a HH channels, then the jax["nodes"] will be # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means # a seperate lookup table is needed to figure out which parameters and states # belong to which mechanism and at which global indices the mechanism lives. - jaxnodes, jaxedges = {"states": {}, "params": {}}, {"states": {}, "params": {}} - jaxglobals = {"states": {}, "params": {}} + keys = ["nodes", "edges", "global"] + jax = dict(zip(keys, [{"states": {}, "params": {}}] * len(keys))) - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] + module_param_states = { + "states": ["v"], + "params": ["radius", "length", "axial_resistivity", "capacitance"], + } - for state in global_states: - jaxnodes["states"][state] = jnp.asarray(self.nodes[state]) - for param in global_params: - jaxnodes["params"][param] = jnp.asarray(self.nodes[param]) + for label, keys in module_param_states.items(): + for key in keys: + jax["global"][label][key] = jnp.asarray(self.nodes[key]) + + for mech in self.channels + self.synapses: + is_channel = isinstance(mech, Channel) + jax_arrays = jax["nodes"] if is_channel else jax["edges"] + data = self.nodes if is_channel else self.edges + + where_mech = data[mech.name] if is_channel else data["type"] == mech.name + mech.indices = jnp.asarray(data.index[where_mech].to_list()) + if isinstance(mech, Synapse): + pre_inds = data["pre_global_comp_index"] + post_inds = data["post_global_comp_index"] + mech.pre_indices = jnp.asarray(pre_inds[where_mech].to_list()) + mech.post_indices = jnp.asarray(post_inds[where_mech].to_list()) + + params = data.loc[where_mech, mech.params.keys()].to_dict(orient="list") + states = data.loc[where_mech, mech.states.keys()].to_dict(orient="list") + + is_global = lambda x: not x.startswith(f"{mech.name}_") + for label, params_or_states in zip(["params", "states"], [params, states]): + for k in params_or_states: + jax_data = jnp.asarray(data[k][data[k].notna()].to_list()) + if not is_global(k): + jax["global"][label][k] = jax_data + else: + jax_arrays[label][k] = jax_data[mech.indices] - for data, mechs in zip( - [self.nodes, self.edges], [self.channels, self.synapses] - ): - is_nodes = "v" in data.columns - jax_arrays = jaxnodes if is_nodes else jaxedges - for mech in mechs: - where_mech = data[mech.name] if is_nodes else data["type"] == mech.name - mech.indices = jnp.asarray(data.index[where_mech].to_list()) - - params = data.loc[where_mech, mech.params.keys()].to_dict(orient="list") - states = data.loc[where_mech, mech.states.keys()].to_dict(orient="list") - - is_local = lambda x: x.startswith(f"{mech.name}_") - for label, params_or_states in zip( - ["params", "states"], [params, states] - ): - for k, v in params_or_states.items(): - v = v if is_local(k) else data[k][data[k].notna()].to_list() - jax_arrays[label][k] = jnp.asarray(v) - - if not is_local(k): - if mech.name not in jaxglobals[label]: - jaxglobals[label][mech.name] = {} - jaxglobals[label][mech.name][k] = mech.indices - - self.jaxnodes = jaxnodes - self.jaxedges = jaxedges - self.jaxglobals = jaxglobals + self.jax = jax def show( self, @@ -1300,10 +1300,8 @@ def _get_all_states_params( """ states_params = {} - jax_states = {**self.jaxnodes["states"], **self.jaxedges["states"]} - jax_params = {**self.jaxnodes["params"], **self.jaxedges["params"]} - for key, data in {**jax_params, **jax_states}.items(): - states_params[key] = data + for key, values, path in iterate_leaves(self.jax): + states_params[key] = values # Override with those parameters set by `.make_trainable()`. for p in pstate: @@ -1355,7 +1353,7 @@ def get_all_parameters( params = module.get_parameters() # i.e. [0, 1, 2] pstate = params_to_pstate(params, module.indices_set_by_trainables) - module.to_jax() # needed for call to module.jaxnodes + module.to_jax() # needed for call to module.jax Args: pstate: The state of the trainable parameters. pstate takes the form @@ -1379,7 +1377,7 @@ def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Get the full initial state of the module from jaxnodes and trainables. + """Get the full initial state of the module from `.jax` and `.trainables`. Args: pstate: The state of the trainable parameters. @@ -1404,31 +1402,34 @@ def _initialize(self): self._init_morph() return self - def _mech_filter_globals(self, dct, active_mech, globals): - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - global_params_or_states_or_currents = [] - global_params_or_states_or_currents += global_states if "v" in dct else [] - global_params_or_states_or_currents += global_params if "radius" in dct else [] + def _filter_by_mech( + self, param_states: Dict, mech: Union[Channel, Synapse] + ) -> Dict: + """Filter params/states to include only those relevant to the active mech. - is_channel = isinstance(active_mech, Channel) - i_mech = active_mech.current_name if is_channel else f"{active_mech.name}_i" - global_params_or_states_or_currents += [i_mech] if i_mech in dct else [] - - mech_inds = active_mech.indices - - filtered_dct = dct.copy() - for key in global_params_or_states_or_currents: - if key in dct: - filtered_dct[key] = dct[key][mech_inds] + Args: + param_states: The param_states dictionary to filter. + mech: The active mechanism. - if active_mech.name in globals: - for key, inds in globals[active_mech.name].items(): - if key in dct: - filter_inds = index_of_a_in_b(mech_inds.reshape(1, -1), inds) - filtered_dct[key] = dct[key][filter_inds.flatten()] - return filtered_dct + Returns: + The filtered dictionary. + """ + is_channel = isinstance(mech, Channel) + i_mech = mech.current_name if is_channel else f"{mech.name}_i" + + filtered_param_states = param_states.copy() + if i_mech in param_states: + filtered_param_states[i_mech] = param_states[i_mech][mech.indices] + + params_and_or_states = ["states"] if "v" in param_states else [] + params_and_or_states += ["params"] if "radius" in param_states else [] + for param_state_key in params_and_or_states: + for key, _, _ in iterate_leaves(self.jax["global"][param_state_key]): + if key in param_states: + param_state_inds = self._inds_of_state_param(key) + filtered_inds = index_of_a_in_b(mech.indices, param_state_inds) + filtered_param_states[key] = param_states[key][filtered_inds] + return filtered_param_states def init_states(self, delta_t: float = 0.025): # TODO FROM #447: MAKE THIS WORK FOR VIEW? @@ -1440,26 +1441,23 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. - self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - states = {} - for key, data in {**self.jaxnodes["states"], **self.jaxedges["states"]}.items(): - states[key] = data + self.to_jax() # Create `.jax` from `.nodes` and `.edges`. # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. # `voltage_solver` could also be `jax.sparse` here, because both of them # build the channel parameters in the same way. - params = self.get_all_parameters([], voltage_solver="jaxley.thomas") + param_states = self._get_all_states_params([], voltage_solver="jaxley.thomas") voltages = self.nodes["v"].to_numpy() for channel in self.channels: - global_states = self.jaxglobals["states"] - global_params = self.jaxglobals["params"] - channel_states = self._mech_filter_globals(states, channel, global_states) - channel_params = self._mech_filter_globals(params, channel, global_params) + channel_param_states = self._filter_by_mech(param_states, channel) init_state = channel.init_state( - channel_states, voltages[channel.indices], channel_params, delta_t + channel_param_states, + voltages[channel.indices], + channel_param_states, + delta_t, ) # `init_state` might not return all channel states. Only the ones that are @@ -1971,38 +1969,50 @@ def _step_channels( ) return states, current_terms - def _step_channels_state( + def _step_mech_state( self, - states, - delta_t, - channels: List[Channel], - channel_nodes: pd.DataFrame, + states: Dict[str, jnp.ndarray], + delta_t: float, + mechs: List, + mech_data: pd.DataFrame, params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: - """One integration step of the channels.""" voltages = states["v"] - for channel in channels: + for mech in mechs: # States updates. - global_states = self.jaxglobals["states"] - global_params = self.jaxglobals["params"] - channel_states = self._mech_filter_globals(states, channel, global_states) - channel_params = self._mech_filter_globals(params, channel, global_params) + mech_states = self._filter_by_mech(states, mech) + mech_params = self._filter_by_mech(params, mech) + v_mech = ( + (voltages[mech.indices],) + if isinstance(mech, Channel) + else (voltages[mech.pre_indices], voltages[mech.post_indices]) + ) - channel_states_updated = channel.update_states( - channel_states, delta_t, voltages[channel.indices], channel_params + mech_states_updated = mech.update_states( + mech_states, delta_t, *v_mech, mech_params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in channel_states_updated.items(): + for key, val in mech_states_updated.items(): param_state_inds = self._inds_of_state_param(key) - channel_inds = channel.indices.reshape(1, -1) - inds = index_of_a_in_b(channel_inds, param_state_inds).flatten() + inds = index_of_a_in_b(mech.indices, param_state_inds) states[key] = states[key].at[inds].set(val) return states + def _step_channels_state( + self, + states, + delta_t, + channels: List[Channel], + channel_nodes: pd.DataFrame, + params: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """One integration step of the channels.""" + return self._step_mech_state(states, delta_t, channels, channel_nodes, params) + def _channel_currents( self, states: Dict[str, jnp.ndarray], @@ -2031,10 +2041,8 @@ def _channel_currents( v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) - global_states = self.jaxglobals["states"] - global_params = self.jaxglobals["params"] - channel_states = self._mech_filter_globals(states, channel, global_states) - channel_params = self._mech_filter_globals(params, channel, global_params) + channel_states = self._filter_by_mech(states, channel) + channel_params = self._filter_by_mech(params, channel) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( channel_states, v_and_perturbed, channel_params @@ -2502,7 +2510,7 @@ def __init__( k: np.intersect1d(v, self._nodes_in_view) for k, v in pointer.groups.items() } - if pointer.jaxnodes: + if pointer.jax: self.to_jax() self._current_view = "view" # if not instantiated via `comp`, `cell` etc. diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 9292d35b..6f86a92e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -20,9 +20,9 @@ compute_children_and_parents, compute_current_density, dtype_aware_concat, + index_of_a_in_b, loc_of_index, merge_cells, - index_of_a_in_b, ) from jaxley.utils.misc_utils import concat_and_ignore_empty, cumsum_leading_zero from jaxley.utils.solver_utils import ( @@ -264,35 +264,7 @@ def _step_synapse_state( delta_t: float, edges: pd.DataFrame, ) -> Dict: - voltages = states["v"] - - for synapse in syn_channels: - pre_inds = edges.loc[synapse.indices, "pre_global_comp_index"].to_numpy() - post_inds = edges.loc[synapse.indices, "post_global_comp_index"].to_numpy() - - global_states = self.jaxglobals["states"] - global_params = self.jaxglobals["params"] - syn_states = self._mech_filter_globals(states, synapse, global_states) - syn_params = self._mech_filter_globals(params, synapse, global_params) - - # State updates. - synapse_states_updated = synapse.update_states( - syn_states, - delta_t, - voltages[pre_inds], - voltages[post_inds], - syn_params, - ) - - # Rebuild state. This has to be done within the loop over channels to allow - # multiple channels which modify the same state. - for key, val in synapse_states_updated.items(): - param_state_inds = self._inds_of_state_param(key) - synapse_inds = synapse.indices.reshape(1, -1) - inds = index_of_a_in_b(synapse_inds, param_state_inds).flatten() - states[key] = states[key].at[inds].set(val) - - return states + return self._step_mech_state(states, delta_t, syn_channels, edges, params) def _synapse_currents( self, @@ -313,19 +285,16 @@ def _synapse_currents( num_comp = len(voltages) synapse_current_states = {f"i_{s.name}": zeros for s in syn_channels} - for i, group in edges.groupby("type_ind"): - synapse = syn_channels[i] - pre_inds = group["pre_global_comp_index"].to_numpy() - post_inds = group["post_global_comp_index"].to_numpy() + for synapse in syn_channels: + pre_inds = synapse.pre_indices + post_inds = synapse.post_indices v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) post_v_and_perturbed = jnp.array([v_post, v_post + diff]) - global_states = self.jaxglobals["states"] - global_params = self.jaxglobals["params"] - syn_states = self._mech_filter_globals(states, synapse, global_states) - syn_params = self._mech_filter_globals(params, synapse, global_params) + syn_states = self._filter_by_mech(states, synapse) + syn_params = self._filter_by_mech(params, synapse) synapse_currents = vmap( synapse.compute_current, in_axes=(None, 0, 0, None) diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index ccbe5833..ae408dbc 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -777,7 +777,7 @@ def dtype_aware_concat(dfs): def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray: """Replace values in A with the indices of the corresponding values in B. - Mainly used to determine the indices of parameters in jaxnodes based on the global + Mainly used to determine the indices of parameters in `jax` based on the global indices of the parameters in the cell. All values in A that are not in B are replaced with -1. @@ -792,12 +792,35 @@ def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray: Returns: Array of shape of A with the indices of the values of A in B.""" + A_is_flat = A.ndim == 1 + A = A.reshape(1, -1) if A_is_flat else A matches = A[:, :, None] == B - # Get mask for values that exist in B - exists_in_B = matches.any(axis=-1) - # Get indices where matches occur - indices = jnp.where(matches, jnp.arange(len(B))[None, None, :], 0) - # Sum along last axis to get the indices - result = jnp.sum(indices, axis=-1) - # Replace values not in B with -1 - return jnp.where(exists_in_B, result, -1) + exists_in_B = matches.any(axis=-1) # mask for vals also in B + indices = jnp.where( + matches, jnp.arange(len(B))[None, None, :], 0 + ) # inds of matches + result = jnp.sum(indices, axis=-1) # Sum along last axis to get the indices + inds = jnp.where(exists_in_B, result, -1) # Replace values not in B with -1 + return inds.flatten() if A_is_flat else inds + + +def iterate_leaves(tree, path=[]): + """Iterate over all leafs (arrays) in a pytree while keeping track of their paths. + + Args: + tree: The pytree to iterate over + path: Current path in the tree (used recursively) + + Yields: + tuple: (final_key, array_value, full_path) + """ + if isinstance(tree, dict): + for key, value in tree.items(): + yield from iterate_leaves(value, path + [key]) + elif isinstance(tree, (list, tuple)): + for i, value in enumerate(tree): + yield from iterate_leaves(value, path + [str(i)]) + else: + # Assuming any non-dict/list/tuple is a leaf node (Array in this case) + if path: # Only yield if we have a path + yield path[-1], tree, path From a204f63c80b3e1c509b73d7b65d278a1d093d39b Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 13 Jan 2025 16:11:51 +0100 Subject: [PATCH 26/26] fix: ammend prev commit --- jaxley/modules/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e75d9df9..e3648010 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1415,7 +1415,7 @@ def _filter_by_mech( The filtered dictionary. """ is_channel = isinstance(mech, Channel) - i_mech = mech.current_name if is_channel else f"{mech.name}_i" + i_mech = mech.current_name if is_channel else f"i_{mech.name}" filtered_param_states = param_states.copy() if i_mech in param_states: @@ -1448,7 +1448,7 @@ def init_states(self, delta_t: float = 0.025): # `voltage_solver` could also be `jax.sparse` here, because both of them # build the channel parameters in the same way. param_states = self._get_all_states_params([], voltage_solver="jaxley.thomas") - voltages = self.nodes["v"].to_numpy() + voltages = param_states["v"] for channel in self.channels: channel_param_states = self._filter_by_mech(param_states, channel)