Skip to content

Commit

Permalink
Bugfixes for set_params etc for synapses
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 11, 2023
1 parent f2b9257 commit 5d49426
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
24 changes: 12 additions & 12 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,17 @@ def __call__(self, loc: float):
return super().adjust_view("comp_index", index)

def connect(self, post, synapse_type):
if type(synapse_type).__name__ not in self.pointer.synapse_names:
new_type = True
synapse_name = type(synapse_type).__name__
if synapse_name not in self.pointer.synapse_names:
new_synapse_type = True
else:
new_type = False
new_synapse_type = False

if new_type:
if new_synapse_type:
max_ind = self.pointer.syn_edges["type_ind"].max() + 1
type_ind = 0 if jnp.isnan(max_ind) else max_ind

else:
# TODO: here, we assume that synapses are added one type after another.
type_ind = self.pointer.syn_edges["type_ind"].to_numpy()[-1]
type_ind = self.pointer.syn_edges.query(f"type == '{synapse_name}'")["type_ind"].to_numpy()[0]

pre_comp = loc_of_index(
self.view["global_comp_index"].to_numpy(), self.pointer.nseg
Expand All @@ -93,7 +92,7 @@ def connect(self, post, synapse_type):
post_branch_index=post.view["branch_index"].to_numpy(),
pre_cell_index=self.view["cell_index"].to_numpy(),
post_cell_index=post.view["cell_index"].to_numpy(),
type=type(synapse_type).__name__,
type=synapse_name,
type_ind=type_ind,
global_pre_comp_index=self.view["global_comp_index"].to_numpy(),
global_post_comp_index=post.view[
Expand All @@ -107,13 +106,14 @@ def connect(self, post, synapse_type):
].to_numpy(),
)
),
]
],
ignore_index=True
)
self.pointer.syn_edges["index"] = list(self.pointer.syn_edges.index)

for key in synapse_type.synapse_params:
param_vals = jnp.asarray([synapse_type.synapse_params[key]])
if new_type:
if new_synapse_type:
self.pointer.syn_params[key] = param_vals
else:
self.pointer.syn_params[key] = jnp.concatenate(
Expand All @@ -122,14 +122,14 @@ def connect(self, post, synapse_type):

for key in synapse_type.synapse_states:
state_vals = jnp.asarray([synapse_type.synapse_states[key]])
if new_type:
if new_synapse_type:
self.pointer.syn_states[key] = state_vals
else:
self.pointer.syn_states[key] = jnp.concatenate(
[self.pointer.syn_states[key], state_vals]
)

if new_type:
if new_synapse_type:
self.pointer.synapse_names.append(type(synapse_type).__name__)
self.pointer.synapse_param_names.append(synapse_type.synapse_params.keys())
self.pointer.synapse_state_names.append(synapse_type.synapse_states.keys())
Expand Down
1 change: 1 addition & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def set_params(self, key: str, val: float):
assert (
key in self.pointer.synapse_param_names[self.view["type_ind"].values[0]]
), f"Parameter {key} does not exist in synapse of type {self.view['type'].values[0]}."
# TODO: have to reset the pointer here.
self.pointer._set_params(key, val, self.view)

def set_states(self, key: str, val: float):
Expand Down

0 comments on commit 5d49426

Please sign in to comment.