diff --git a/debugging/histogram1.npy b/debugging/histogram1.npy index 85fae37143..b98cfbaf11 100644 Binary files a/debugging/histogram1.npy and b/debugging/histogram1.npy differ diff --git a/debugging/histogram2.npy b/debugging/histogram2.npy index af6fa99ca8..293aae23c4 100644 Binary files a/debugging/histogram2.npy and b/debugging/histogram2.npy differ diff --git a/debugging/histogram3.npy b/debugging/histogram3.npy index 8d003239b1..48585ba89d 100644 Binary files a/debugging/histogram3.npy and b/debugging/histogram3.npy differ diff --git a/debugging/peak_locs_1.npy b/debugging/peak_locs_1.npy index 67d2aaa0fd..74e7b3177e 100644 Binary files a/debugging/peak_locs_1.npy and b/debugging/peak_locs_1.npy differ diff --git a/debugging/peak_locs_2.npy b/debugging/peak_locs_2.npy index 622cdbe65a..232a1b3087 100644 Binary files a/debugging/peak_locs_2.npy and b/debugging/peak_locs_2.npy differ diff --git a/debugging/peak_locs_3.npy b/debugging/peak_locs_3.npy index 9ff848da0f..0e8f0e045f 100644 Binary files a/debugging/peak_locs_3.npy and b/debugging/peak_locs_3.npy differ diff --git a/debugging/peaks_1.npy b/debugging/peaks_1.npy index 36e777c303..02bb934e1b 100644 Binary files a/debugging/peaks_1.npy and b/debugging/peaks_1.npy differ diff --git a/debugging/peaks_2.npy b/debugging/peaks_2.npy index d314d20050..095223b4ec 100644 Binary files a/debugging/peaks_2.npy and b/debugging/peaks_2.npy differ diff --git a/debugging/peaks_3.npy b/debugging/peaks_3.npy index 426cc94d4a..c66301ccfd 100644 Binary files a/debugging/peaks_3.npy and b/debugging/peaks_3.npy differ diff --git a/debugging/playing.py b/debugging/playing.py index cdec4d6481..9bbde26ae6 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -30,12 +30,13 @@ # -------------------------------------------------------------------------------------- recordings_list, _ = generate_session_displacement_recordings( - num_units=20, + num_units=65, recording_durations=[400, 400, 400], - recording_shifts=((0, 0), (0, -200), (0, 100)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient - non_rigid_gradient=None, # 0.1, - seed=2, # 52 + recording_shifts=((0, 0), (0, -200), (0, 150)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient + non_rigid_gradient=0.1, # 0.1, + seed=5, # 52 ) + if False: import numpy as np @@ -60,7 +61,6 @@ detect_kwargs={"method": "locally_exclusive"}, localize_peaks_kwargs={"method": "grid_convolution"}, ) - # if False: np.save("peaks_1.npy", peaks_list[0]) np.save("peaks_2.npy", peaks_list[1]) np.save("peaks_3.npy", peaks_list[2]) @@ -68,7 +68,7 @@ np.save("peak_locs_2.npy", peak_locations_list[1]) np.save("peak_locs_3.npy", peak_locations_list[2]) - # if False: + # if False: peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")] peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")] @@ -81,16 +81,16 @@ # See `session_alignment.py` for docs on these settings. non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() - non_rigid_window_kwargs["rigid_mode"] = "nonrigid" + non_rigid_window_kwargs["rigid_mode"] = "rigid" non_rigid_window_kwargs["win_shape"] = "rect" - non_rigid_window_kwargs["win_step_um"] = 200.0 - non_rigid_window_kwargs["win_scale_um"] = 300.0 + non_rigid_window_kwargs["win_step_um"] = 300.0 + non_rigid_window_kwargs["win_scale_um"] = 400.0 estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() estimate_histogram_kwargs["method"] = "chunked_median" estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly estimate_histogram_kwargs["bin_um"] = 5 - estimate_histogram_kwargs["log_scale"] = True + estimate_histogram_kwargs["log_scale"] = False estimate_histogram_kwargs["weight_with_amplitude"] = True compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index 9fd7f31395..cc9650acde 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from numpy.lib.histograms import histogram if TYPE_CHECKING: from spikeinterface.core.baserecording import BaseRecording @@ -15,6 +16,8 @@ import copy import scipy import matplotlib.pyplot as plt +from scipy.ndimage import gaussian_filter +import matplotlib.pyplot as plt INTERP = "linear" @@ -560,6 +563,7 @@ def _get_single_session_activity_histogram( log_scale, bin_s=chunked_bin_size_s, depth_smooth_um=depth_smooth_um, + weight_with_amplitude=weight_with_amplitude, scale_to_hz=True, ) @@ -840,106 +844,15 @@ def _correct_session_displacement( return corrected_peak_locations_list, corrected_session_histogram_list -def cross_correlate(sig1, sig2, thr=None): - xcorr = np.correlate(sig1, sig2, mode="full") - - n = sig1.size - low_cut_idx = np.arange(0, n - thr) # double check - high_cut_idx = np.arange(n + thr, 2 * n - 1) - - xcorr[low_cut_idx] = 0 - xcorr[high_cut_idx] = 0 - - if np.max(xcorr) < 0.01: - shift = 0 - else: - shift = np.argmax(xcorr) - xcorr.size // 2 - - return shift - def _correlate(signal1, signal2): corr_value = ( np.corrcoef(signal1, signal2)[0, 1] ) - if False: - corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size - return corr_value - -def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plot=True, round=0): - """ """ - best_correlation = 0 - best_displacements = np.zeros_like(signal1_blanked) - # TODO: use kriging interp - - xcorr = [] - - for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 - - nonzero = np.where(signal1_blanked > 0)[0] - if not np.any(nonzero): - continue - - midpoint = nonzero[0] + np.ptp(nonzero) / 2 - x_scale = (x - midpoint) * scale + midpoint - - # interp_f = scipy.interpolate.interp1d( - # x_scale, signal1_blanked, fill_value=0.0, bounds_error=False - # ) # TODO: try cubic etc... or Kriging - - # scaled_func = interp_f(x) - - for sh in np.arange(-thr, thr): # TODO: we are off by one here - - # shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) - - x_shift = x_scale - sh - - interp_f = scipy.interpolate.interp1d( - x_shift, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP - ) - shift_signal1_blanked = interp_f(x) - - from scipy.ndimage import gaussian_filter - - corr_value = _correlate( - gaussian_filter(shift_signal1_blanked, 1.5), - gaussian_filter(signal2_blanked, 1.5) - ) - - if np.isnan(corr_value) or corr_value < 0: - corr_value = 0 - - if corr_value > best_correlation: - best_displacements = x_shift - best_correlation = corr_value - - # if plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25: - # print("3") - # plt.plot(shift_signal1_blanked) - # plt.plot(signal2_blanked) - # plt.title(corr_value) - # plt.show() - # plt.draw() - # plt.pause(0.1) - # plt.clf() - if False and plot: - print("DONE)") - plt.plot(signal1_blanked) - plt.plot(signal2_blanked) - plt.show() - - interp_f = scipy.interpolate.interp1d( - best_displacements, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP - ) - final = interp_f(x) - plt.plot(final) - plt.plot(signal2_blanked) - plt.show() - - return best_displacements + corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size + return corr_value def _interp(signal, current_coords, orig_coords): @@ -948,18 +861,39 @@ def _interp(signal, current_coords, orig_coords): ) return interp_f(orig_coords) +# Am not using amplitude properly! Revisit with aim to properly use amplitude information +# and preform the task as is done by eye! + +# 1) fix the interpolation issue for the 35-unit test case +# 2) fix the overall loss-tracking +# 3) combine pairwise and union approach +# 4) make a stimulation benchmark and test all +# 5) visualise the data after interpolation. Apply to real data +# 6) Revising the concept, the way of blanking is not very nice. It would +# be better to fix areas with regularisation and stop some regions moving into others. +# 7) dynamically adjust the allowed window movement +# 8) dynamically adjust the gaussian smoothing window (currently unused) +# 9) make a version for 2D + # Be extremely careful, if the fixed_indices is not correct, everything will be messed # up and it will be very hard to detect! # the windowing will just adjust blanked_mask and fixed_windows... # For windowing, it would be easier just to pass the sinals directly and get displacements back # 1) project 2) window 3) correlate and add shifts -def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot): +# notes +# splitting window adds a lot of overhead and not sure it helps much. Makes very uninterpretable. +# better to iteratively maintain joint information for as long as possible +# TODO: +# definlately revise the window sizes. should go to zero basicall on the second of third window? Basically overfitting... + + +def cross_correlate_with_scaled_fixed(x_orig, new_positions, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot): """ """ best_correlation = 0 best_positions = np.zeros(histogram_array_blanked.shape[1]) - histogram_array_blanked = histogram_array_blanked.copy() + # histogram_array_blanked = histogram_array_blanked.copy() for i_ in range(histogram_array_blanked.shape[0]): histogram_array_blanked[i_, fixed_windows[i_, :]] = 0 @@ -968,14 +902,11 @@ def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed # just need to do this to find the scaling midpoint... signal1_for_midpoint = _interp(histogram_array_blanked[i, :], new_positions[i, :], x_orig) - signal1_for_midpoint[blanked_mask] = 0 nonzero = np.where(signal1_for_midpoint > 0)[0] if not np.any(nonzero): - return new_positions[i, :] # no change + return new_positions[i, :] midpoint = nonzero[0] + np.ptp(nonzero) / 2 - best_s1 = None - for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 x_scale = (new_positions[i, :] - midpoint) * scale + midpoint @@ -984,28 +915,17 @@ def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed putative_new_x = x_scale - sh - # x_shift_ = new_positions[i, :].copy() # TODO - # x_shift_[~fixed_windows[i, :]] = putative_new_x[~fixed_windows[i, :]] - # putative_new_x = x_shift_ - shift_signal1 = _interp(histogram_array_blanked[i, :], putative_new_x, x_orig) - from scipy.ndimage import gaussian_filter - - s1 = shift_signal1.copy() - s2 = signal2.copy() - s1[blanked_mask] = 0 - s2[blanked_mask] = 0 - corr_value = _correlate( - s1, # shift_signal1[~blanked_mask], # gaussian_filter(histogram_array_blanked[i, :], 0.5), # TODO: need to adapt to kinetics of the data - s2, # signal2[~blanked_mask], # gaussian_filter(signal2_blanked, 0.5) + shift_signal1, # gaussian_filter(histogram_array_blanked[i, :], 0.5), # TODO: need to adapt to kinetics of the data + signal2, # gaussian_filter(signal2_blanked, 0.5) ) - percent_diff = np.exp(-(np.abs(1 - np.sum(s1) / np.sum(histogram_array_blanked[i, :]))) ** 2 / 1.2 ** 2) ** 12 + percent_diff = np.exp(-(np.abs(1 - np.sum(shift_signal1) / np.sum(histogram_array_blanked[i, :]) ) ) ** 2 / 1.2 ** 2) # ** 6 corr_value *= percent_diff # heavily penalise interpolation errors - # corr_value *= 1 - np.abs(sh - 0) / thr + corr_value *= 1 - np.abs(sh - 0) / thr if np.isnan(corr_value) or corr_value < 0: corr_value = 0 @@ -1013,358 +933,233 @@ def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed if corr_value > best_correlation: best_positions = putative_new_x best_correlation = corr_value - best_s1 = s1 - - if round_ > 0: - plt.plot(s1) - plt.plot(s2) - plt.title(corr_value) - plt.draw() - plt.pause(0.1) - plt.clf() new_pos = new_positions[i, :].copy() - new_pos[~fixed_windows[i, :]] = best_positions[~fixed_windows[i, :]] + new_pos[~fixed_windows[i, :]] = best_positions[~fixed_windows[i, :]] # hmm still not ideal... + return new_pos -def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, orig_blank_histograms, interp_blanked_histograms, thr, round_, plot): +def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array, thr, round_, plot): """""" + blanked_histogram_array = histogram_array.copy() + for i in range(histogram_array.shape[0]): + blanked_histogram_array[i, fixed_windows[i, :]] = 0 - # while True: - for i in range(interp_blanked_histograms.shape[0]): + for i in range(blanked_histogram_array.shape[0]): best_correlation = 0 - best_positions = np.zeros_like(interp_blanked_histograms[i, :]) + best_positions = np.zeros_like(blanked_histogram_array[i, :]) - interp_blanked_histograms = np.zeros_like(orig_blank_histograms) - for j in range(orig_blank_histograms.shape[0]): - interp_f = scipy.interpolate.interp1d( - new_positions[j, :], orig_blank_histograms[j, :], fill_value=0.0, bounds_error=False, kind=INTERP - ) - interp_blanked_histograms[j, :] = interp_f(x_orig) + # Interpolate the other histograms. We use the i histogram for the midpoint but otherwise ignore + interp_blanked_histograms = np.zeros_like(blanked_histogram_array) + for j in range(blanked_histogram_array.shape[0]): + interp_blanked_histograms[j, :] = _interp(blanked_histogram_array[j, :], new_positions[j, :], x_orig) + + # find the scaling midpoint + nonzero = np.where(interp_blanked_histograms[i, :] > 0)[0] + if not np.any(nonzero): + print("CONTINUED") + continue + midpoint = nonzero[0] + np.ptp(nonzero) / 2 - for scale in np.r_[np.linspace(0.75, 1, 15), np.linspace(1, 1.25, 15)]: # TODO: double 1 + best_si = None - nonzero = np.where(interp_blanked_histograms[i, :] > 0)[0] - if not np.any(nonzero): - continue + for scale in np.r_[np.linspace(0.85, 1, 15), np.linspace(1, 1.15, 15)]: # TODO: double 1 - midpoint = nonzero[0] + np.ptp(nonzero) / 2 x_scale = (new_positions[i, :] - midpoint) * scale + midpoint for sh in np.arange(-thr, thr): # TODO: we are off by one here - x_shift = x_scale - sh - - x_shift_ = new_positions[i, :].copy() # TODO - x_shift_[~fixed_windows[i, :]] = x_shift[~fixed_windows[i, :]] - x_shift = x_shift_ + putative_new_x = x_scale - sh - interp_f = scipy.interpolate.interp1d( - x_shift, orig_blank_histograms[i, :], fill_value=0.0, bounds_error=False, kind=INTERP - ) - shift_signal1_blanked = interp_f(x_orig) - - from scipy.ndimage import gaussian_filter + signal_i_shift = _interp(blanked_histogram_array[i, :], putative_new_x, x_orig) corr_value = 0 - for j in range(interp_blanked_histograms.shape[0]): if i == j: continue + corr_value += _correlate(signal_i_shift,interp_blanked_histograms[j, :]) - # gaussian_filter(shift_signal1_blanked, 1.5), # TODO: need to adapt to kinetics of the data - # gaussian_filter(interp_blanked_histograms[j, :], 1.5) - corr_value += _correlate(gaussian_filter(shift_signal1_blanked, 1.5), gaussian_filter(interp_blanked_histograms[j, :], 1.5)) + percent_diff = np.exp(-(np.abs(1 - np.sum(signal_i_shift) / np.sum(interp_blanked_histograms[i, :]))) ** 2 / 1.2 ** 2) # ** 6 + corr_value *= percent_diff # heavily penalise interpolation errors # corr_value *= 1 - np.abs(sh - 0) / thr - percent_diff = np.exp(-(np.abs(1 - np.sum(shift_signal1_blanked) / np.sum(orig_blank_histograms[i, :])))**2/1.5**2)**6 - corr_value *= percent_diff # heavily penalise interpolation errors - if np.isnan(corr_value) or corr_value < 0: corr_value = 0 if corr_value > best_correlation: - best_positions = x_shift + best_positions = putative_new_x best_correlation = corr_value + best_si = signal_i_shift - plt.plot(shift_signal1_blanked) - for j in range(interp_blanked_histograms.shape[0]): - if i == j: - continue - plt.plot(interp_blanked_histograms[j, :].T) - plt.title(corr_value) - plt.draw() - plt.pause(0.1) - plt.clf() + if False: + plt.plot(signal_i_shift) + for j in range(interp_blanked_histograms.shape[0]): + if i == j: + continue + plt.plot(interp_blanked_histograms[j, :]) + plt.title(corr_value) + plt.draw() + plt.pause(0.1) + plt.clf() - - print("FINAL i update", i) - new_positions[i, :] = best_positions # new_positions[i, :] + (best_positions - x_orig) - - interp_f = scipy.interpolate.interp1d( - new_positions[i, :], orig_blank_histograms[i, :], fill_value=0.0, bounds_error=False, kind=INTERP - ) - interp_blanked_histograms[i, :] = interp_f(x_orig) - - if False: - print("AFTER)") - plt.plot(interp_blanked_histograms[i, :]) - for j in range(interp_blanked_histograms.shape[0]): - if i == j: - continue - plt.plot(interp_blanked_histograms[j, :].T) - plt.show() + new_positions[i, ~fixed_windows[i, :]] = best_positions[~fixed_windows[i, :]] return new_positions def get_threshold_array(num_bins, windows): num_points = len(windows) - max = num_bins // 2 + max = num_bins // 2 # TODO: should probably try both, and take maximmum! min = windows[0].size // 5 k = -np.log(min / (max - min)) / (num_points - 1) x_values = np.arange(num_points) all_thr = (max - min) * np.exp(-1.2 * x_values) + min return all_thr +def get_shifts_pairwise(signal1, signal2, windows, plot=True): - -def get_shifts_union(histogram_array, windows, plot=True): import matplotlib.pyplot as plt - plot = True - - histogram_array_blanked = histogram_array.copy() - - x_orig = np.arange(histogram_array_blanked.shape[1]) - new_positions = np.vstack([x_orig.copy()] * histogram_array_blanked.shape[0]) - fixed_windows = np.zeros_like(histogram_array_blanked).astype(bool) - - windows_to_run = np.arange(len(windows)) - - all_thr = get_threshold_array(histogram_array.shape[1], windows) # TOOD: tidy + signal1_blanked = signal1.copy(signal1) + signal2_blanked = signal2.copy(signal2) - loss = 0 + best_displacements = np.zeros_like(signal1) - blanked_mask = np.zeros(histogram_array.shape[1]).astype(bool) + x = np.arange(signal1_blanked.size) + x_orig = x.copy() - for round in range(len(windows)): + all_thr = get_threshold_array(signal1.size, windows) + for round in range(num_points): thr = all_thr[round] - # find contigious window ids - diffs = np.diff(windows_to_run) - block_boundaries = np.where(diffs > 1)[0] # Find indices where the difference is greater than 1 - all_blocks = np.split(windows_to_run, block_boundaries + 1) - - for block in all_blocks: - - print("BLOCK", block) - - window_indexes = [] - block_bools = np.ones(histogram_array.shape[1]).astype(bool) - for block_idx in block: - block_bools[windows[block_idx]] = False - window_indexes.append(windows[block_idx]) - window_indexes = np.hstack(window_indexes) - - if round< 100: # TODO: maybe some function of num windows? - shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1])) - - - y = np.zeros_like(histogram_array_blanked) - for i in range(histogram_array_blanked.shape[0]): - interpf = scipy.interpolate.interp1d( - new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, - bounds_error=False, - kind=INTERP - ) - y[i, :] = interpf(x_orig) - - print("BEFORE") - plt.plot(y.T) - plt.show() - - - for i in range(histogram_array.shape[0]): - for j in range(histogram_array.shape[0]): - - fixed_windows_orig = np.ones_like(histogram_array).astype(bool) - - # TODO: DIRECT COPY!!! - if window_indexes[0] == 0: - window_min = np.min([new_positions[i, :], new_positions[j, :]]) - 1 - else: - window_min = window_indexes[0] - - if window_indexes[-1] == x_orig[-1]: - window_max = np.max([new_positions[i, :], new_positions[j, :]]) + 1 - else: - window_max = window_indexes[-1] - - fixed_indices = np.where(np.logical_and(new_positions[i, :] >= window_min, - new_positions[i, :] <= window_max)) - fixed_windows_orig[i, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING - - fixed_indices = np.where(np.logical_and(new_positions[j, :] >= window_min, - new_positions[j, :] <= window_max)) - fixed_windows_orig[j, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING - # DIRECT COPY END - - # from the window, and from the block - fixed_windows_round = np.logical_or(fixed_windows, fixed_windows_orig) # this is in orig space, different for all. - - blanked_mask_round = np.logical_or(blanked_mask, block_bools) # this is interp space, same for all + displacements = cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=thr, plot=plot, round=round) - shift_matrix[i, j, :] = cross_correlate_with_scaled_fixed( - x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr=thr, round_=round, plot=plot - ) + interpf = scipy.interpolate.interp1d( + displacements, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP + ) + signal1_blanked = interpf(x) - this_round_new_positions = np.mean(shift_matrix, axis=1) # TODO: FIX! TODO: these are not displacements + window_corrs = np.empty(len(windows)) + for i, idx in enumerate(windows): + window_corrs[i] = _correlate(signal1_blanked[idx], signal2_blanked[idx]) + window_corrs[np.isnan(window_corrs)] = 0 + if np.any(window_corrs): + max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case... + best_displacements[windows[max_window]] = displacements[windows[max_window]] + signal1_blanked[windows[max_window]] = 0 + signal2_blanked[windows[max_window]] = 0 + x = displacements + return np.floor(best_displacements - x_orig) - else: - # Not bad for evne no blanking! - fixed_windows_round = block_bools #np.logical_or(fixed_windows, block_bools) - histogram_array_blanked_interp_new = histogram_array_blanked_interp.copy() - histogram_array_blanked_new = histogram_array_blanked.copy() - for j in range(histogram_array_blanked_interp.shape[0]): - histogram_array_blanked_interp_new[j, ~fixed_windows_round] = 0 +# TODO: try running loss +# the nonlinear fundamentally doesn't work well for lots of units! +def get_shifts_union(histogram_array, windows, plot=True): + """ + """ + x_orig = np.arange(histogram_array.shape[1]) + new_positions = np.vstack([x_orig.copy()] * histogram_array.shape[0]) + fixed_windows = np.zeros_like(histogram_array).astype(bool) - # todo: direct copy - this_windows = [] - for block_idx in block: - this_windows.append(windows[block_idx]) - this_windows = np.hstack(this_windows) + all_thr = get_threshold_array(histogram_array.shape[1], windows) - if this_windows[0] == 0: - window_min = np.min(new_positions[j, :]) - 1 - else: - window_min = this_windows[0] + clipped_windows = [] - if this_windows[-1] == x_orig[-1]: - window_max = np.max(new_positions[j, :]) + 1 - else: - window_max = this_windows[-1] + for round in range(len(windows)): - fixed_indices = np.where( - np.logical_and(new_positions[j, :] > window_min, new_positions[j, :] < window_max) - ) + thr = all_thr[round] - histogram_array_blanked_new[j, fixed_indices] = 0 - print("INTERP") - plt.plot(histogram_array_blanked_interp_new.T) - plt.show() + if round == 0: + shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1])) - y = np.zeros_like(histogram_array_blanked) - for i in range(histogram_array_blanked.shape[0]): - interpf = scipy.interpolate.interp1d( - new_positions[i, :], histogram_array_blanked_new[i, :], fill_value=0.0, - bounds_error=False, - kind=INTERP + for i in range(histogram_array.shape[0]): + for j in range(histogram_array.shape[0]): + shift_matrix[i, j, :] = cross_correlate_with_scaled_fixed( + x_orig, new_positions, fixed_windows, histogram_array, i, j, thr=thr, round_=round, plot=plot ) - y[i, :] = interpf(x_orig) - - print("INTERPED") - plt.plot(y.T) - plt.show() - - this_round_new_positions = cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array_blanked_new, histogram_array_blanked_interp_new, thr, round, plot=True) - - + new_positions = np.mean(shift_matrix, axis=1) + else: + new_positions = cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array, thr, round, plot=True) - histogram_array_interp = np.zeros_like(histogram_array_blanked) - for i in range(histogram_array_blanked.shape[0]): - interpf = scipy.interpolate.interp1d( - this_round_new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, bounds_error=False, - kind=INTERP - ) - histogram_array_interp[i, :] = interpf(x_orig) - histogram_array_interp[i, blanked_mask] = 0 + histogram_array_interp = np.zeros_like(histogram_array) + histogram_array_ = histogram_array.copy() + for i in range(histogram_array.shape[0]): + histogram_array_[i, fixed_windows[i, : ]] = 0 + histogram_array_interp[i, :] = _interp(histogram_array_[i, :], new_positions[i, :], x_orig) window_corrs = np.empty(len(windows)) # okay need to increase but shouldn't fail for one window for i, idx in enumerate(windows): - window_corrs[i] = np.sum(np.triu(np.corrcoef(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small - # plt.plot(histogram_array_interp[:, idx].T) - # plt.show() - window_corrs[np.isnan(window_corrs)] = 0 - window_corrs = np.abs(window_corrs) + window_corrs[i] = np.sum(np.triu(np.cov(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small + window_corrs[np.isnan(window_corrs)] = 0 + window_corrs[window_corrs < 0] = 0 print(window_corrs) + + # Now fix indices and blank in the originals space if np.any(window_corrs): max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case... + clipped_windows.append(max_window) + for i in range(histogram_array.shape[0]): + # this should never happen! this is because windows are allowd to move into frozen + # windows. This should be explicitly handled! We now have to iterative across + # all windows because some signal ends up going into previously closed windows + for mw in clipped_windows: + if windows[mw][0] == 0: + window_min = np.min(new_positions[i, :]) - 1 + else: + window_min = windows[mw][0] - if windows[max_window][0] == 0: - window_min = np.min(this_round_new_positions) - 1 - else: - window_min = windows[max_window][0] + if windows[mw][-1] == x_orig[-1]: + window_max = np.max(new_positions[i, :]) + 1 + else: + window_max = windows[mw][-1] - if windows[max_window][-1] == x_orig[-1]: - window_max = np.max(this_round_new_positions) + 1 - else: - window_max = windows[max_window][-1] + fixed_indices = np.where(np.logical_and(np.ceil(new_positions[i, :]) >= window_min, np.floor(new_positions[i, :]) <= window_max)) + fixed_windows[i, fixed_indices] = True - fixed_indices = np.where(np.logical_and(this_round_new_positions[i, :] >= window_min, this_round_new_positions[i, :] <= window_max)) - fixed_windows[i, fixed_indices] = True # TODO: CAREUFULLY CHECK MAPPING + window = fixed_windows[i, :] + histogram_array_copy = histogram_array.copy() + histogram_array_copy[i, window] = 0 - blanked_mask[windows[max_window]] = True - window_corrs[max_window] = 0 - windows_to_run = np.delete(windows_to_run, np.where(windows_to_run == max_window)[0]) + # TODO: need to investigate the interpolation issiue + if False: + print("ORIG") + plt.plot(histogram_array_copy[i, :]) + plt.show() - # if round == 1 or not np.any(window_corrs > 0.1): # TODO: definately keep a running track of the xcorr and quit when it gets worse or doesn't improve. See how this example does across the rounds - # break + print("INTERP") + new = _interp(histogram_array_copy[i, :], new_positions[i, :], x_orig) + plt.plot(new) + plt.show() - # final = np.zeros_like(histogram_array_blanked) - # for i in range(histogram_array_blanked.shape[0]): - # interpf = scipy.interpolate.interp1d( - # this_round_new_positions[i], histogram_array[i, :], fill_value=0.0, bounds_error=False, kind=INTERP - # ) - # final[i, :] = interpf(x_orig) + breakpoint() - # loss_ = 0 # okay need to increase but shouldn't fail for one window - # loss_ += np.sum( - # np.triu(np.cov(histogram_array_interp), k=1) - # ) - new_positions = this_round_new_positions # TODO - - if not np.any(window_corrs > 0.01): # TODO: KEY <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HADNLE THIS FIRST - break + window_corrs[max_window] = 0 - # if round == 1: - # break - # if round == 0: - # break - # print("loss_", loss_) - ## if loss_ < loss: - # break - # else: - # new_positions = this_round_new_positions # TODO - # loss = loss_ + if np.any(window_corrs > 0.001): # still not resolved, this is dependent on cov so changes every time. A running loss will be much better. + break -# print("FINAL") - # plt.plot(final.T) - # plt.show() + return np.ceil(x_orig - new_positions) # or round? - # going to have to check the improvement in fit for every round and - # if the round does not add much to the loss, then don't make the - # change from this round! - # if not np.any(window_corrs > 0.01): # TODO: KEY <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HADNLE THIS FIRST - # break - return np.ceil(x_orig - new_positions) +######################################################################################################################## +# Combine this with above, this uses many older techniques so basically just leverage the above +# for this purpose +######################################################################################################################## def get_shifts_pairwise(signal1, signal2, windows, plot=True): @@ -1379,17 +1174,16 @@ def get_shifts_pairwise(signal1, signal2, windows, plot=True): x = np.arange(signal1_blanked.size) x_orig = x.copy() - all_thr = get_threshold_array(signal1.size, windows) # TOOD: tidy + all_thr = get_threshold_array(signal1.size, windows) for round in range(num_points): - thr = all_thr[round] # TODO: optimise this somehow? go back and forth? + thr = all_thr[round] - print(f"ROUND: {round}, THR: {thr}") displacements = cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=thr, plot=plot, round=round) interpf = scipy.interpolate.interp1d( displacements, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP - ) # TODO: move away from this indexing sceheme + ) signal1_blanked = interpf(x) window_corrs = np.empty(len(windows)) @@ -1407,19 +1201,62 @@ def get_shifts_pairwise(signal1, signal2, windows, plot=True): x = displacements - if False and plot: - print("FINAL") - plt.plot(signal1) - plt.plot(signal2) - plt.show() + return np.floor(best_displacements - x_orig) - interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP) - final = interpf(x_orig) - plt.plot(final) - plt.plot(signal2) - plt.show() - return np.floor(best_displacements - x_orig) +def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plot=True, round=0): + """ """ + best_correlation = 0 + best_displacements = np.zeros_like(signal1_blanked) + + # TODO: use kriging interp + + xcorr = [] + + for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 + + nonzero = np.where(signal1_blanked > 0)[0] + if not np.any(nonzero): + continue + + midpoint = nonzero[0] + np.ptp(nonzero) / 2 + x_scale = (x - midpoint) * scale + midpoint + + # interp_f = scipy.interpolate.interp1d( + # x_scale, signal1_blanked, fill_value=0.0, bounds_error=False + # ) # TODO: try cubic etc... or Kriging + + # scaled_func = interp_f(x) + + for sh in np.arange(-thr, thr): # TODO: we are off by one here + + # shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) + + x_shift = x_scale - sh + + interp_f = scipy.interpolate.interp1d( + x_shift, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP + ) + shift_signal1_blanked = interp_f(x) + + from scipy.ndimage import gaussian_filter + + corr_value = _correlate( + gaussian_filter(shift_signal1_blanked, 1.5), + gaussian_filter(signal2_blanked, 1.5) + ) + + if np.isnan(corr_value) or corr_value < 0: + corr_value = 0 + + if corr_value > best_correlation: + best_displacements = x_shift + best_correlation = corr_value + + return best_displacements + + +######################################################################################################################## def _compute_session_alignment( @@ -1494,52 +1331,24 @@ def _compute_session_alignment( nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0])) - # windows = [] - # for i in range(non_rigid_windows.shape[0]): - # idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)] - # windows.append(idxs) - # TODO: check assumptions these are always the same size - # windows = np.vstack(windows) - num_windows = non_rigid_windows.shape[0] - windows = np.arange(shifted_histograms.shape[1]) windows = np.array_split(windows, num_windows) - # import matplotlib.pyplot as plt - # plt.plot(non_rigid_windows.T) - # plt.show() - # num_windows = - # windows1 = windows[::2, :] - nonrigid_session_offsets_matrix = np.empty( (shifted_histograms.shape[0], shifted_histograms.shape[0], spatial_bin_centers.size) ) - print("NUM WINDOWS: ", num_windows) - mode = "centered" if mode == "centered": - plot_ = False non_rigid_shifts = get_shifts_union(shifted_histograms, windows, plot_) - else: for i in range(shifted_histograms.shape[0]): for j in range(shifted_histograms.shape[0]): - - plot_ = False # i == 0 and j == 1 - print("I", i) - print("J", j) - + plot_ = False shifts1 = get_shifts_pairwise(shifted_histograms[i, :], shifted_histograms[j, :], windows, plot=plot_) - # shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) - # shifts = np.empty(shifts1.size + shifts1.size - 1) - # breakpoint() - # shifts[::2] = shifts1 - # shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2 # np.shifts2 - # breakpoint() nonrigid_session_offsets_matrix[i, j, :] = shifts1 # TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 00a7dd6e05..53a2099d71 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -500,7 +500,7 @@ def make_2d_motion_histogram( arr[:, 1] = peak_locations[direction] if weight_with_amplitude: - weights = np.abs(peaks["amplitude"]) + weights = np.abs(peaks["amplitude"]) * 10 else: weights = None