diff --git a/debugging/playing.py b/debugging/playing.py index d405c3871a..dba45ad5b9 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -94,7 +94,7 @@ recordings_list, peaks_list, peak_locations_list, - alignment_order="to_session_1", # "to_session_X" or "to_middle" + alignment_order="to_session_2", # "to_session_X" or "to_middle" non_rigid_window_kwargs=non_rigid_window_kwargs, estimate_histogram_kwargs=estimate_histogram_kwargs, ) diff --git a/debugging/playing2.py b/debugging/playing2.py index 2b01faa334..63710586a8 100644 --- a/debugging/playing2.py +++ b/debugging/playing2.py @@ -33,66 +33,84 @@ def cross_correlate(sig1, sig2, thr= None): return shift -def cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=100, plot=True): +def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True): """ """ + best_correlation = 0 + best_displacements = np.zeros_like(signa11_blanked) + + # TODO: use kriging interp + xcorr = [] - for s in np.arange(-thr, thr): # TODO: we are off by one here - shift_signal1_blanked = shift_array_fill_zeros(signa11_blanked, s) + for scale in np.linspace(0.85, 1.15, 10): + + nonzero = np.where(signa11_blanked > 0)[0] + if not np.any(nonzero): + continue + + midpoint = nonzero[0] + np.ptp(nonzero) / 2 + x_scale = (x - midpoint) * scale + midpoint + + interp_f = scipy.interpolate.interp1d(x_scale, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging + + scaled_func = interp_f(x) + + # plt.plot(signa11_blanked) + # plt.plot(scaled_func) + # plt.show() - x = np.arange(shift_signal1_blanked.size) + # breakpoint() - xcorr_scale = [] - for scale in np.linspace(0.75, 1.25, 10): + for sh in np.arange(-thr, thr): # TODO: we are off by one here - midpoint = np.argmax(shift_signal1_blanked) # assumes x is 0 .. n TODO: IMPROVE - xs = (x - midpoint) * scale + midpoint + shift_signal1_blanked = shift_array_fill_zeros(scaled_func, sh) + + x_shift = x_scale - sh # TODO: rename # is this pull back? - interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging + # interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging - scaled_func = interp_f(x) + # scaled_func = interp_f(x_shift) corr_value = np.correlate( - scaled_func - np.mean(scaled_func), + shift_signal1_blanked - np.mean(shift_signal1_blanked), signal2_blanked - np.mean(signal2_blanked), ) / signa11_blanked.size - xcorr_scale.append( - corr_value - ) + if corr_value > best_correlation: + best_displacements = x_shift + best_correlation = corr_value - if plot and corr_value > 0.0045: # and np.abs(s) < 10: + if False and np.abs(sh) == 1: print(corr_value) plt.plot(shift_signal1_blanked) plt.plot(signal2_blanked) plt.show() + # plt.draw() # Draw the updated figure + # plt.pause(0.1) # Pause for 0.5 seconds before updating + # plt.clf() - plt.plot(scaled_func) - plt.plot(signal2_blanked) - plt.show() - # plt.title(f"corr value: {corr_value}") - # plt.draw() # Draw the updated figure - # plt.pause(0.1) # Pause for 0.5 seconds before updating - # plt.clf() + # breakpoint() - xcorr.append(np.max(np.r_[xcorr_scale])) - xcorr = np.r_[xcorr] -# shift = np.argmax(xcorr) - thr + # xcorr.append(np.max(np.r_[xcorr_scale])) - print("MAX", np.max(xcorr)) + if False: + xcorr = np.r_[xcorr] + # shift = np.argmax(xcorr) - thr - if np.max(xcorr) < 0.0001: - shift = 0 - else: - shift = np.argmax(xcorr) - thr + print("MAX", np.max(xcorr)) + + if np.max(xcorr) < 0.0001: + shift = 0 + else: + shift = np.argmax(xcorr) - thr - print("output shift", shift) + print("output shift", shift) - return shift + return best_displacements # plt.plot(signal1) # plt.plot(signal2) @@ -104,6 +122,8 @@ def get_shifts(signal1, signal2, windows, plot=True): signa11_blanked = signal1.copy() signal2_blanked = signal2.copy() + best_displacements = np.zeros_like(signal1) + if (first_idx := windows[0][0]) != 0: print("first idx", first_idx) signa11_blanked[:first_idx] = 0 @@ -115,29 +135,39 @@ def get_shifts(signal1, signal2, windows, plot=True): signal2_blanked[last_idx:] = 0 segment_shifts = np.empty(len(windows)) - cum_shifts = [] + x = np.arange(signa11_blanked.size) + x_orig = x.copy() + for round in range(len(windows)): - if round == 0: - shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! - else: - shift = cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=100, plot=False) + #if round == 0: + # shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! + #else: + displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False) + + + # breakpoint() - cum_shifts.append(shift) - print("shift", shift) + interpf = scipy.interpolate.interp1d(displacements, signa11_blanked, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme + signa11_blanked = interpf(x) + + + + # cum_shifts.append(shift) + # print("shift", shift) # shift the signal1, or use indexing - signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) +# signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors... - if plot: - print("round", round) - plt.plot(signa11_blanked) - plt.plot(signal2_blanked) - plt.show() + # if plot: + # print("round", round) + # plt.plot(signa11_blanked) + # plt.plot(signal2_blanked) + # plt.show() window_corrs = np.empty(len(windows)) for i, idx in enumerate(windows): @@ -148,15 +178,28 @@ def get_shifts(signal1, signal2, windows, plot=True): max_window = np.argmax(window_corrs) # TODO: cutoff! - small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2) + if False: + small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2) + signa11_blanked = shift_array_fill_zeros(signa11_blanked, small_shift) + segment_shifts[max_window] = np.sum(cum_shifts) + small_shift - signa11_blanked = shift_array_fill_zeros(signa11_blanked, small_shift) + best_displacements[windows[max_window]] = displacements[windows[max_window]] - segment_shifts[max_window] = np.sum(cum_shifts) + small_shift + x = displacements signa11_blanked[windows[max_window]] = 0 signal2_blanked[windows[max_window]] = 0 + # TODO: need to carry over displacements! + + print(best_displacements) + interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False) # TODO: move away from this indexing sceheme + final = interpf(x_orig) + + plt.plot(final) + plt.plot(signal2) + plt.show() + return segment_shifts diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index ba2cd55787..949a740bf0 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -13,6 +13,7 @@ from spikeinterface.preprocessing.inter_session_alignment import alignment_utils from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node import copy +import scipy def get_estimate_histogram_kwargs() -> dict: @@ -837,7 +838,7 @@ def _correct_session_displacement( return corrected_peak_locations_list, corrected_session_histogram_list -def cross_correlate(sig1, sig2, thr= None): +def cross_correlate(sig1, sig2, thr=None): xcorr = np.correlate(sig1, sig2, mode="full") n = sig1.size @@ -854,82 +855,94 @@ def cross_correlate(sig1, sig2, thr= None): return shift -def cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=100, plot=False): - """ - """ - import scipy + +def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True): + """ """ + best_correlation = 0 + best_displacements = np.zeros_like(signa11_blanked) + + # TODO: use kriging interp + xcorr = [] - for s in np.arange(-thr, thr): # TODO: we are off by one here - shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, s) + for scale in np.linspace(0.85, 1.15, 10): - x = np.arange(shift_signal1_blanked.size) + nonzero = np.where(signa11_blanked > 0)[0] + if not np.any(nonzero): + continue - xcorr_scale = [] - for scale in np.linspace(0.75, 1.25, 10): + midpoint = nonzero[0] + np.ptp(nonzero) / 2 + x_scale = (x - midpoint) * scale + midpoint + interp_f = scipy.interpolate.interp1d( + x_scale, signa11_blanked, fill_value=0.0, bounds_error=False + ) # TODO: try cubic etc... or Kriging - nonzero = np.where(shift_signal1_blanked > 0)[0] - if not np.any(nonzero): - xcorr_scale.append( - 0 - ) - continue + scaled_func = interp_f(x) + + # plt.plot(signa11_blanked) + # plt.plot(scaled_func) + # plt.show() + + # breakpoint() + + for sh in np.arange(-thr, thr): # TODO: we are off by one here - midpoint = nonzero[0] + np.ptp(nonzero) / 2 + shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) - xs = (x - midpoint) * scale + midpoint + x_shift = x_scale - sh # TODO: rename # is this pull back? - # TODO: maybe upsample 10x here... - interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging + # interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging - scaled_func = interp_f(x) + # scaled_func = interp_f(x_shift) - corr_value = np.correlate( - scaled_func - np.mean(scaled_func), + corr_value = ( + np.correlate( + shift_signal1_blanked - np.mean(shift_signal1_blanked), signal2_blanked - np.mean(signal2_blanked), - ) / signa11_blanked.size - - xcorr_scale.append( - corr_value + ) + / signa11_blanked.size ) - if plot and np.abs(s) < 2 and False: - import matplotlib.pyplot as plt + if corr_value > best_correlation: + best_displacements = x_shift + best_correlation = corr_value + + if False and np.abs(sh) == 1: print(corr_value) plt.plot(shift_signal1_blanked) plt.plot(signal2_blanked) plt.show() + # plt.draw() # Draw the updated figure + # plt.pause(0.1) # Pause for 0.5 seconds before updating + # plt.clf() - plt.plot(scaled_func) - plt.plot(signal2_blanked) - plt.show() - # plt.title(f"corr value: {corr_value}") - # plt.draw() # Draw the updated figure - # plt.pause(0.1) # Pause for 0.5 seconds before updating - # plt.clf() + # breakpoint() - xcorr.append(np.max(np.r_[xcorr_scale])) + # xcorr.append(np.max(np.r_[xcorr_scale])) - xcorr = np.r_[xcorr] -# shift = np.argmax(xcorr) - thr + if False: + xcorr = np.r_[xcorr] + # shift = np.argmax(xcorr) - thr - print("MAX", np.max(xcorr)) + print("MAX", np.max(xcorr)) - if np.max(xcorr) < 0.0001: - shift = 0 - else: - shift = np.argmax(xcorr) - thr + if np.max(xcorr) < 0.0001: + shift = 0 + else: + shift = np.argmax(xcorr) - thr - print("output shift", shift) + print("output shift", shift) + + return best_displacements - return shift # plt.plot(signal1) # plt.plot(signal2) + def get_shifts(signal1, signal2, windows, plot=True): import matplotlib.pyplot as plt @@ -937,6 +950,8 @@ def get_shifts(signal1, signal2, windows, plot=True): signa11_blanked = signal1.copy() signal2_blanked = signal2.copy() + best_displacements = np.zeros_like(signal1) + if (first_idx := windows[0][0]) != 0: print("first idx", first_idx) signa11_blanked[:first_idx] = 0 @@ -948,49 +963,78 @@ def get_shifts(signal1, signal2, windows, plot=True): signal2_blanked[last_idx:] = 0 segment_shifts = np.empty(len(windows)) - cum_shifts = [] + x = np.arange(signa11_blanked.size) + x_orig = x.copy() for round in range(len(windows)): - if round == 0: - shift = cross_correlate(signa11_blanked, signal2_blanked, thr=150) # for first rigid, do larger! - else: - shift = cross_correlate_with_scale(signa11_blanked, signal2_blanked, thr=150, plot=False) + # if round == 0: + # shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! + # else: + displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False) + + # breakpoint() + interpf = scipy.interpolate.interp1d( + displacements, signa11_blanked, fill_value=0.0, bounds_error=False + ) # TODO: move away from this indexing sceheme + signa11_blanked = interpf(x) - cum_shifts.append(shift) - print("shift", shift) + # cum_shifts.append(shift) + # print("shift", shift) # shift the signal1, or use indexing - signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, shift) + # signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors... - if plot and False: - print("round", round) - plt.plot(signa11_blanked) - plt.plot(signal2_blanked) - plt.show() + # if plot: + # print("round", round) + # plt.plot(signa11_blanked) + # plt.plot(signal2_blanked) + # plt.show() window_corrs = np.empty(len(windows)) for i, idx in enumerate(windows): - window_corrs[i] = np.correlate( - signa11_blanked[idx] - np.mean(signa11_blanked[idx]), - signal2_blanked[idx] - np.mean(signal2_blanked[idx]), - ) / signa11_blanked[idx].size + window_corrs[i] = ( + np.correlate( + signa11_blanked[idx] - np.mean(signa11_blanked[idx]), + signal2_blanked[idx] - np.mean(signal2_blanked[idx]), + ) + / signa11_blanked[idx].size + ) max_window = np.argmax(window_corrs) # TODO: cutoff! - small_shift = cross_correlate(signa11_blanked[windows[max_window]], signal2_blanked[windows[max_window]], thr=windows[max_window].size //2) + if False: + small_shift = cross_correlate( + signa11_blanked[windows[max_window]], + signal2_blanked[windows[max_window]], + thr=windows[max_window].size // 2, + ) + signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift) + segment_shifts[max_window] = np.sum(cum_shifts) + small_shift - signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift) + best_displacements[windows[max_window]] = displacements[windows[max_window]] - segment_shifts[max_window] = np.sum(cum_shifts) + small_shift + x = displacements signa11_blanked[windows[max_window]] = 0 signal2_blanked[windows[max_window]] = 0 - return segment_shifts + # TODO: need to carry over displacements! + + print(best_displacements) + interpf = scipy.interpolate.interp1d( + best_displacements, signal1, fill_value=0.0, bounds_error=False + ) # TODO: move away from this indexing sceheme + final = interpf(x_orig) + + # plt.plot(final) + # plt.plot(signal2) + # plt.show() + + return np.floor(best_displacements - x_orig) def _compute_session_alignment( @@ -1065,12 +1109,12 @@ def _compute_session_alignment( nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0])) - # windows = [] - # for i in range(non_rigid_windows.shape[0]): - # idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)] - # windows.append(idxs) - # TODO: check assumptions these are always the same size -# windows = np.vstack(windows) + # windows = [] + # for i in range(non_rigid_windows.shape[0]): + # idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)] + # windows.append(idxs) + # TODO: check assumptions these are always the same size + # windows = np.vstack(windows) num_windows = non_rigid_windows.shape[0] @@ -1084,7 +1128,7 @@ def _compute_session_alignment( # windows1 = windows[::2, :] nonrigid_session_offsets_matrix = np.empty( - (shifted_histograms.shape[0], shifted_histograms.shape[0], non_rigid_windows.shape[0]) + (shifted_histograms.shape[0], shifted_histograms.shape[0], spatial_bin_centers.size) ) print("NUM WINDOWS: ", num_windows) @@ -1092,14 +1136,13 @@ def _compute_session_alignment( for i in range(shifted_histograms.shape[0]): for j in range(shifted_histograms.shape[0]): - plot_ = j == 2 and i == 0 - shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=True) - # shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) - # shifts = np.empty(shifts1.size + shifts1.size - 1) + + # shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) + # shifts = np.empty(shifts1.size + shifts1.size - 1) # breakpoint() - # shifts[::2] = shifts1 - # shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2 # np.shifts2 + # shifts[::2] = shifts1 + # shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2 # np.shifts2 # breakpoint() nonrigid_session_offsets_matrix[i, j, :] = shifts1 @@ -1110,21 +1153,26 @@ def _compute_session_alignment( # nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( # shifted_histograms, non_rigid_windows, **compute_alignment_kwargs # ) - non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) - - # Akima interpolate the nonrigid bins if required. - if akima_interp_nonrigid: - interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( - non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers - ) - shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts - non_rigid_window_centers = spatial_bin_centers - else: - # TODO: so check + add a test, the interpolator will handle this? - shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts + non_rigid_shifts = nonrigid_session_offsets_matrix[ + 2, :, : + ] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + non_rigid_window_centers = spatial_bin_centers + shifts = non_rigid_shifts + + if False: + # Akima interpolate the nonrigid bins if required. + if akima_interp_nonrigid: + interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( + non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers + ) + shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts + non_rigid_window_centers = spatial_bin_centers + else: + # TODO: so check + add a test, the interpolator will handle this? + shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts - if rigid_mode == "rigid_nonrigid": - shifts += rigid_shifts + if rigid_mode == "rigid_nonrigid": + shifts += rigid_shifts return shifts, non_rigid_windows, non_rigid_window_centers diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index e54721a447..00a7dd6e05 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -59,7 +59,10 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) assert all(t.ndim == 1 for t in self.temporal_bins_s) - assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + try: + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + except: + breakpoint() def __repr__(self): nbins = self.spatial_bins_um.shape[0]