Skip to content

Commit

Permalink
fix several bugs (make it work in the first place
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 16, 2024
1 parent 0ac7d75 commit 815877b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ def integrate(
# At least one stimulus was inserted.
if module.currents is not None or data_stimuli is not None:
if module.currents is not None:
i_current = module.currents.T
i_current = module.currents # Shape `(num_stimuli, time)`.
i_inds = module.current_inds.comp_index.to_numpy()

if data_stimuli is not None:
# Append stimuli from `data_stimuli`.
i_current = jnp.concatenate(
[i_current, jnp.expand_dims(data_stimuli[0], axis=0)]
)
i_inds = np.concatenate([i_inds, data_stimuli[1].comp_index.to_numpy()])
i_current = jnp.concatenate([i_current, data_stimuli[0]])
i_inds = jnp.concatenate([i_inds, data_stimuli[1].comp_index.to_numpy()])
else:
i_current = data_stimuli[0]
i_current = data_stimuli[0] # Shape `(num_stimuli, time)`
i_inds = data_stimuli[1].comp_index.to_numpy()

i_current = i_current.T # Shape `(time, num_stimuli)`.
else:
# No stimulus was inserted.
i_current = jnp.asarray([[]]).astype("int32")
Expand Down
2 changes: 1 addition & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def data_stimulate(
):
"""Insert a stimulus into the module within jit (or grad)."""
nodes = self.set_global_index_and_index(self.view)
self.pointer._data_stimulate(current, data_stimuli, nodes)
return self.pointer._data_stimulate(current, data_stimuli, nodes)

def set(self, key: str, val: float):
"""Set parameters of the pointer."""
Expand Down

0 comments on commit 815877b

Please sign in to comment.