diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index b7c76820a..e61c0fa4b 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -63,18 +63,17 @@ def __call__(self, loc: float): return super().adjust_view("comp_index", index) def connect(self, post, synapse_type): - if type(synapse_type).__name__ not in self.pointer.synapse_names: - new_type = True + synapse_name = type(synapse_type).__name__ + if synapse_name not in self.pointer.synapse_names: + new_synapse_type = True else: - new_type = False + new_synapse_type = False - if new_type: + if new_synapse_type: max_ind = self.pointer.syn_edges["type_ind"].max() + 1 type_ind = 0 if jnp.isnan(max_ind) else max_ind - else: - # TODO: here, we assume that synapses are added one type after another. - type_ind = self.pointer.syn_edges["type_ind"].to_numpy()[-1] + type_ind = self.pointer.syn_edges.query(f"type == '{synapse_name}'")["type_ind"].to_numpy()[0] pre_comp = loc_of_index( self.view["global_comp_index"].to_numpy(), self.pointer.nseg @@ -93,7 +92,7 @@ def connect(self, post, synapse_type): post_branch_index=post.view["branch_index"].to_numpy(), pre_cell_index=self.view["cell_index"].to_numpy(), post_cell_index=post.view["cell_index"].to_numpy(), - type=type(synapse_type).__name__, + type=synapse_name, type_ind=type_ind, global_pre_comp_index=self.view["global_comp_index"].to_numpy(), global_post_comp_index=post.view[ @@ -107,13 +106,14 @@ def connect(self, post, synapse_type): ].to_numpy(), ) ), - ] + ], + ignore_index=True ) self.pointer.syn_edges["index"] = list(self.pointer.syn_edges.index) for key in synapse_type.synapse_params: param_vals = jnp.asarray([synapse_type.synapse_params[key]]) - if new_type: + if new_synapse_type: self.pointer.syn_params[key] = param_vals else: self.pointer.syn_params[key] = jnp.concatenate( @@ -122,14 +122,14 @@ def connect(self, post, synapse_type): for key in synapse_type.synapse_states: state_vals = jnp.asarray([synapse_type.synapse_states[key]]) - if new_type: + if new_synapse_type: self.pointer.syn_states[key] = state_vals else: self.pointer.syn_states[key] = jnp.concatenate( [self.pointer.syn_states[key], state_vals] ) - if new_type: + if new_synapse_type: self.pointer.synapse_names.append(type(synapse_type).__name__) self.pointer.synapse_param_names.append(synapse_type.synapse_params.keys()) self.pointer.synapse_state_names.append(synapse_type.synapse_states.keys()) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a101008e2..8d97807bc 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -500,6 +500,7 @@ def set_params(self, key: str, val: float): assert ( key in self.pointer.synapse_param_names[self.view["type_ind"].values[0]] ), f"Parameter {key} does not exist in synapse of type {self.view['type'].values[0]}." + # TODO: have to reset the pointer here. self.pointer._set_params(key, val, self.view) def set_states(self, key: str, val: float):