Skip to content

Commit

Permalink
Tidy up tests and (slightly) improve num_iter handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jan 20, 2025
1 parent 93b89f6 commit abe27c1
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ def compute_histogram_crosscorrelation(
center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int)

# Create the (num windows, num_bins) matrix for this pair of sessions
num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1)
if num_shifts is None:
num_shifts = num_bins - 1
shifts_array = np.arange(-(num_shifts), num_shifts + 1)
num_iter = shifts_array.size

for i in range(num_sessions):
for j in range(i, num_sessions):
Expand All @@ -291,7 +293,7 @@ def compute_histogram_crosscorrelation(
window_i = windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis]
window_j = windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis]

xcorr = np.zeros(num_iter + 1)
xcorr = np.zeros(num_iter)

for idx, shift in enumerate(shifts_array):
shifted_i = shift_array_fill_zeros(window_i, shift)
Expand All @@ -309,7 +311,7 @@ def compute_histogram_crosscorrelation(
mode="full",
)
if num_shifts:
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts + 1)
xcorr = xcorr[window_indices]

xcorr_matrix[win_idx, :] = xcorr
Expand All @@ -322,7 +324,6 @@ def compute_histogram_crosscorrelation(
if num_windows > 1 and smoothing_sigma_window:
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_window, axes=0)

shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1) # TODO: double check
# Upsample the cross-correlation
if interpolate:

Expand All @@ -337,44 +338,16 @@ def compute_histogram_crosscorrelation(
kriging_d,
)

# breakpoint()

xcorr_matrix_old = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
xcorr_matrix_ = np.zeros(
(xcorr_matrix.shape[0], shifts_upsampled.size)
) # TODO: check in nonlinear case
for i_ in range(xcorr_matrix.shape[0]):
xcorr_matrix_[i_, :] = np.matmul(xcorr_matrix[i_, :], K)

# breakpoint()

plt.plot(shifts_array, xcorr_matrix.T)
plt.show
plt.plot(shifts_upsampled, xcorr_matrix_.T)
plt.show()

xcorr_matrix = xcorr_matrix_

# plt.plot(xcorr_matrix.T)
# plt.plot(xcorr_matrix_old.T)
# plt.show()
#
xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])

xcorr_peak = np.argmax(xcorr_matrix, axis=1)
shift = shifts_upsampled[xcorr_peak]

# breakpoint()

else:
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
shift = shifts_array[xcorr_peak]

# x=i;y=j
# breakpoint()
shift_matrix[i, j, :] = shift

breakpoint()

# As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
# the (empty) lower triangular with the negative (already computed) upper triangular to save computation
for k in range(shift_matrix.shape[2]):
Expand Down
Loading

0 comments on commit abe27c1

Please sign in to comment.