diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index 678b1e1e..b8a1dc41 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -59,22 +59,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.channel_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_params.items() + for key, value in self.params.items() } - self.channel_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index c19bf002..70fc72b5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -17,7 +17,7 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 0.12, f"{prefix}_gK": 0.036, f"{prefix}_gLeak": 0.0003, @@ -25,7 +25,7 @@ def __init__(self, name: Optional[str] = None): f"{prefix}_eK": -77.0, f"{prefix}_eLeak": -54.3, } - self.channel_states = { + self.states = { f"{prefix}_m": 0.2, f"{prefix}_h": 0.2, f"{prefix}_n": 0.2, diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 5884deac..8602a72c 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gLeak": 1e-4, f"{prefix}_eLeak": -70.0, } - self.channel_states = {} + self.states = {} self.current_name = f"i_{prefix}" def update_states( @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 50e-3, "eNa": 50.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} + self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} self.current_name = f"i_Na" def update_states( @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gK": 5e-3, "eK": -90.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_n": 0.2} + self.states = {f"{prefix}_n": 0.2} self.current_name = f"i_K" def update_states( @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gKm": 0.004e-3, f"{prefix}_taumax": 4000.0, f"eK": -90.0, } - self.channel_states = {f"{prefix}_p": 0.2} + self.states = {f"{prefix}_p": 0.2} self.current_name = f"i_K" def update_states( @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaL": 0.1e-3, "eCa": 120.0, } - self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} + self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} self.current_name = f"i_Ca" def update_states( @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaT": 0.4e-4, f"{prefix}_vx": 2.0, "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.channel_states = {f"{prefix}_u": 0.2} + self.states = {f"{prefix}_u": 0.2} self.current_name = f"i_Ca" def update_states( diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 5ac3a4aa..14a30e38 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -223,7 +223,6 @@ def __getattr__(self, key): view._set_controlled_by_param(key) # overwrites param set by edge # Ensure synapse param sharing works with `edge` # `edge` will be removed as part of #463 - view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -710,17 +709,30 @@ def to_jax(self): they can be processed on GPU/TPU and such that the simulation can be differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. """ - jaxnodes = self.base.jaxnodes = {} - nodes = self.base.nodes.to_dict(orient="list") + jaxnodes, jaxedges = {}, {} - jaxedges = self.base.jaxedges = {} - edges = self.base.edges.to_dict(orient="list") - edges.pop("type") # drop since column type is string + for jax_arrays, data, mechs in zip( + [jaxnodes, jaxedges], + [self.nodes, self.edges], + [self.channels, self.synapses], + ): + jax_arrays.update({"index": data.index.to_numpy()}) + all_inds = jax_arrays["index"] + for mech in mechs: + inds = ( + all_inds[data["type"] == mech._name] + if "type" in data.columns + else all_inds[self.nodes[mech._name]] + ) + states_params = list(mech.params) + list(mech.states) + params = data[states_params].loc[inds] + jax_arrays.update({mech._name: inds}) + jax_arrays.update(params.to_dict(orient="list")) - for jax_array, params in zip([jaxnodes, jaxedges], [nodes, edges]): - for key, value in params.items(): - inds = jnp.arange(len(value)) - jax_array[key] = jnp.asarray(value)[inds] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list")) + jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -753,12 +765,8 @@ def show( inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds cols += [ch._name for ch in self.channels] if channel_names else [] - cols += ( - sum([list(ch.channel_params) for ch in self.channels], []) if params else [] - ) - cols += ( - sum([list(ch.channel_states) for ch in self.channels], []) if states else [] - ) + cols += sum([list(ch.params) for ch in self.channels], []) if params else [] + cols += sum([list(ch.states) for ch in self.channels], []) if states else [] if not param_names is None: cols = ( @@ -887,12 +895,8 @@ def set_ncomp( start_idx = self.nodes["global_comp_index"].to_numpy()[0] nseg_per_branch = self.base.nseg_per_branch channel_names = [c._name for c in self.base.channels] - channel_param_names = list( - chain(*[c.channel_params for c in self.base.channels]) - ) - channel_state_names = list( - chain(*[c.channel_states for c in self.base.channels]) - ) + channel_param_names = list(chain(*[c.params for c in self.base.channels])) + channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns within_branch_radiuses = view["radius"].to_numpy() @@ -1213,11 +1217,12 @@ def get_all_parameters( A dictionary of all module parameters. """ params = {} - for key in ["radius", "length", "axial_resistivity", "capacitance"]: + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + for key in ["v"] + morph_params: params[key] = self.base.jaxnodes[key] for channel in self.base.channels: - for channel_params in channel.channel_params: + for channel_params in channel.params: params[channel_params] = self.base.jaxnodes[channel_params] for synapse_params in self.base.synapse_param_names: @@ -1250,7 +1255,7 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states = {"v": self.base.jaxnodes["v"]} # Join node and edge states into a single state dictionary. for channel in self.base.channels: - for channel_states in channel.channel_states: + for channel_states in channel.states: states[channel_states] = self.base.jaxnodes[channel_states] for synapse_states in self.base.synapse_state_names: states[synapse_states] = self.base.jaxedges[synapse_states] @@ -1333,8 +1338,8 @@ def init_states(self, delta_t: float = 0.025): ].to_numpy() voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) + channel_param_names = list(channel.params.keys()) + channel_state_names = list(channel.states.keys()) channel_states = query_states_and_params( states, channel_state_names, channel_indices ) @@ -1653,12 +1658,12 @@ def insert(self, channel: Channel): self.base.nodes.loc[self._nodes_in_view, name] = True # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_params: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key] + for key in channel.params: + self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key] # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_states: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + for key in channel.states: + self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key] @only_allow_module def step( @@ -1836,11 +1841,9 @@ def _step_channels_state( query_channel = lambda d, names: query_states_and_params( d, names, channel_inds ) - channel_param_names = list(channel.channel_params) + morph_params + channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) - channel_state_names = ( - list(channel.channel_states) + self.membrane_current_names - ) + channel_state_names = list(channel.states) + self.membrane_current_names channel_states = query_channel(states, channel_state_names) # States updates. @@ -1884,9 +1887,9 @@ def _channel_currents( query_channel = lambda d, names: query_states_and_params( d, names, channel_inds ) - channel_param_names = list(channel.channel_params) + morph_params + channel_param_names = list(channel.params) + morph_params channel_params = query_channel(params, channel_param_names) - channel_states = query_channel(states, channel.channel_states) + channel_states = query_channel(states, channel.states) v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) @@ -2406,13 +2409,9 @@ def _filter_trainables( ): pkey, pval = next(iter(params.items())) trainable_inds_in_view = None - if pkey in sum( - [list(c.channel_params.keys()) for c in self.base.channels], [] - ): + if pkey in sum([list(c.params.keys()) for c in self.base.channels], []): trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) - elif pkey in sum( - [list(s.synapse_params.keys()) for s in self.base.synapses], [] - ): + elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []): trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) in_view = is_viewed == np.isin(inds, trainable_inds_in_view) @@ -2475,8 +2474,8 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]): viewed_synapses += ( [syn] if in_view else [None] ) # padded with None to keep indices consistent - viewed_params += list(syn.synapse_params.keys()) if in_view else [] - viewed_states += list(syn.synapse_states.keys()) if in_view else [] + viewed_params += list(syn.params.keys()) if in_view else [] + viewed_states += list(syn.states.keys()) if in_view else [] self.synapses = viewed_synapses self.synapse_param_names = viewed_params self.synapse_state_names = viewed_states diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 103add35..ba477870 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -269,8 +269,8 @@ def _step_synapse_state( edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.synapse_params) - synapse_states = query_syn(states, synapse.synapse_states) + synapse_params = query_syn(params, synapse.params) + synapse_states = query_syn(states, synapse.states) # State updates. states_updated = synapse.update_states( @@ -313,8 +313,8 @@ def _synapse_currents( edge_inds = group.index.to_numpy() query_syn = lambda d, names: query_states_and_params(d, names, edge_inds) - synapse_params = query_syn(params, synapse.synapse_params) - synapse_states = query_syn(states, synapse.synapse_states) + synapse_params = query_syn(params, synapse.params) + synapse_states = query_syn(states, synapse.states) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) @@ -535,8 +535,8 @@ def _infer_synapse_type_ind(self, synapse_name): def _update_synapse_state_names(self, synapse_type): # (Potentially) update variables that track meta information about synapses. self.base.synapse_names.append(synapse_type._name) - self.base.synapse_param_names += list(synapse_type.synapse_params.keys()) - self.base.synapse_state_names += list(synapse_type.synapse_states.keys()) + self.base.synapse_param_names += list(synapse_type.params.keys()) + self.base.synapse_state_names += list(synapse_type.states.keys()) self.base.synapses.append(synapse_type) def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): @@ -586,9 +586,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): def _add_params_to_edges(self, synapse_type, indices): # Add parameters and states to the `.edges` table. - for key, param_val in synapse_type.synapse_params.items(): + for key, param_val in synapse_type.params.items(): self.base.edges.loc[indices, key] = param_val # Update synaptic state array. - for key, state_val in synapse_type.synapse_states.items(): + for key, state_val in synapse_type.states.items(): self.base.edges.loc[indices, key] = state_val diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index da89113f..101dd95b 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_e_syn": 0.0, f"{prefix}_k_minus": 0.025, } - self.synapse_states = {f"{prefix}_s": 0.2} + self.states = {f"{prefix}_s": 0.2} def update_states( self, diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index a3b4752f..38cd7d3f 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -38,22 +38,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.synapse_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_params.items() + for key, value in self.params.items() } - self.synapse_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index e006a278..6bbd49cc 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -16,12 +16,12 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_x_offset": -70.0, f"{prefix}_slope": 1.0, } - self.synapse_states = {} + self.states = {} def update_states( self, diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 49a7311e..84cb5d4d 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -19,8 +19,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = {f"{prefix}_gC": 1e-4} - self.synapse_states = {f"{prefix}_c": 0.2} + self.params = {f"{prefix}_gC": 1e-4} + self.states = {f"{prefix}_c": 0.2} def update_states( self, diff --git a/tests/test_channels.py b/tests/test_channels.py index 41024040..54ae8577 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -25,13 +25,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" @@ -84,8 +84,8 @@ def __init__( "T": 279.45, # Kelvin (temperature) "R": 8.314, # J/(mol K) (gas constant) } - self.channel_params = {} - self.channel_states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} + self.params = {} + self.states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} self.current_name = f"i_Ca" def update_states(self, u, dt, voltages, params): @@ -117,21 +117,21 @@ def test_channel_set_name(): # channel name can be set in the constructor na = Na(name="NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() # channel name can not be changed directly k = K() with pytest.raises(AttributeError): k.name = "KPospischil" - assert "KPospischil_gNa" not in k.channel_params.keys() - assert "eNa" not in k.channel_params.keys() - assert "KPospischil_h" not in k.channel_states.keys() - assert "KPospischil_m" not in k.channel_states.keys() + assert "KPospischil_gNa" not in k.params.keys() + assert "eNa" not in k.params.keys() + assert "KPospischil_h" not in k.states.keys() + assert "KPospischil_m" not in k.states.keys() def test_channel_change_name(): @@ -139,12 +139,12 @@ def test_channel_change_name(): # (and only this way after initialization) na = Na().change_name("NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() def test_integration_with_renamed_channels(): @@ -201,12 +201,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_q10_ch": 3, f"{prefix}_q10_ch0": 22, "celsius": 22, } - self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} + self.states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} self.current_name = f"i_K" def update_states( @@ -292,8 +292,8 @@ class User(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"cumulative": 0.0} + self.params = {} + self.states = {"cumulative": 0.0} self.current_name = f"i_User" def update_states(self, states, dt, v, params): @@ -308,8 +308,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -322,8 +322,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 3e7642ce..9f9f66bd 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -22,8 +22,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy1" @staticmethod @@ -45,8 +45,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy2" @staticmethod @@ -68,10 +68,10 @@ class CaHVA(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gCaHVA": 0.00001, # S/cm^2 } - self.channel_states = { + self.states = { f"{self._name}_m": 0.1, # Initial value for m gating variable f"{self._name}_h": 0.1, # Initial value for h gating variable "eCa": 0.0, # mV, assuming eca for demonstration @@ -140,13 +140,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" diff --git a/tests/test_syn.py b/tests/test_syn.py index 89d5ab6f..3ae48123 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -30,7 +30,7 @@ def test_set_and_querying_params_one_type(): connect(pre, post, IonotropicSynapse()) # Get the synapse parameters to test setting - syn_params = list(IonotropicSynapse().synapse_params.keys()) + syn_params = list(IonotropicSynapse().params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 136a38d7..8f414cdc 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -74,7 +74,7 @@ def test_set_and_querying_params_one_type(synapse_type): connect(pre, post, synapse_type) # Get the synapse parameters to test setting - syn_params = list(synapse_type.synapse_params.keys()) + syn_params = list(synapse_type.params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) @@ -114,8 +114,8 @@ def test_set_and_querying_params_two_types(synapse_type): post = net.cell(post_ind).branch(0).loc(0.0) connect(pre, post, synapse) - type1_params = list(IonotropicSynapse().synapse_params.keys()) - synapse_type_params = list(synapse_type.synapse_params.keys()) + type1_params = list(IonotropicSynapse().params.keys()) + synapse_type_params = list(synapse_type.params.keys()) default_synapse_type = net.edges[synapse_type_params[0]].to_numpy()[[1, 3]]