[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 13, 2025
1 parent 80eaf24 commit 891d69c
Showing 2 changed files with 63 additions and 49 deletions.
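
The diff below appears to consist of mechanical formatting fixes applied by the pre-commit.ci bot. As a minimal sketch of how such fixes are produced (the repository's actual hook configuration is not shown on this page, so the exact hook set is an assumption), the hooks can be re-run locally against one of the touched files:

# Illustration only: re-run the repository's pre-commit hooks on one of the changed files.
# Assumes pre-commit is installed and a .pre-commit-config.yaml exists at the repository root.
import subprocess

subprocess.run(
    ["pre-commit", "run", "--files", "debugging/playing.py"],
    check=False,  # pre-commit exits non-zero when hooks modify or reject files
)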
2 changes: 1 addition & 1 deletion debugging/playing.py
@@ -137,4 +137,4 @@
np.save("histogram1.npy", extra_info["session_histogram_list"][0])
np.save("histogram2.npy", extra_info["session_histogram_list"][1])
np.save("histogram3.npy", extra_info["session_histogram_list"][2])
breakpoint()
@@ -21,6 +21,7 @@

INTERP = "linear"


def get_estimate_histogram_kwargs() -> dict:
"""
A dictionary controlling how the histogram for each session is
@@ -846,21 +847,17 @@ def _correct_session_displacement(

def _correlate(signal1, signal2):

corr_value = (
np.corrcoef(signal1,
signal2)[0, 1]
)
corr_value = np.corrcoef(signal1, signal2)[0, 1]

corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size
return corr_value
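
A side note on _correlate above: the second assignment overwrites the first, so the value returned is the mean-subtracted cross-correlation at zero lag (an unnormalized covariance estimate), not the Pearson coefficient from np.corrcoef. A standalone sketch, not part of the diff, makes the difference visible:

# Illustration only: contrast the two quantities computed inside _correlate.
import numpy as np

rng = np.random.default_rng(0)
signal1 = rng.normal(size=200)
signal2 = 0.5 * signal1 + rng.normal(scale=0.1, size=200)

pearson = np.corrcoef(signal1, signal2)[0, 1]  # scale-free, bounded in [-1, 1]
zero_lag = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size
# zero_lag is a length-1 array holding a (biased) covariance estimate; this is what _correlate returns.
print(pearson, zero_lag)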


def _interp(signal, current_coords, orig_coords):
interp_f = scipy.interpolate.interp1d(
current_coords, signal, fill_value=0.0, bounds_error=False, kind=INTERP
)
interp_f = scipy.interpolate.interp1d(current_coords, signal, fill_value=0.0, bounds_error=False, kind=INTERP)
return interp_f(orig_coords)


# Am not using amplitude properly! Revisit with aim to properly use amplitude information
# and perform the task as is done by eye!

@@ -887,13 +884,14 @@ def _interp(signal, current_coords, orig_coords):
# definitely revise the window sizes. Should go to zero basically on the second or third window? Basically overfitting...


def cross_correlate_with_scaled_fixed(x_orig, new_positions, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot):
"""
"""
def cross_correlate_with_scaled_fixed(
x_orig, new_positions, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot
):
""" """
best_correlation = 0
best_positions = np.zeros(histogram_array_blanked.shape[1])

# histogram_array_blanked = histogram_array_blanked.copy()
for i_ in range(histogram_array_blanked.shape[0]):
histogram_array_blanked[i_, fixed_windows[i_, :]] = 0

@@ -918,11 +916,13 @@ def cross_correlate_with_scaled_fixed(x_orig, new_positions, fixed_windows, hist
shift_signal1 = _interp(histogram_array_blanked[i, :], putative_new_x, x_orig)

corr_value = _correlate(
shift_signal1, # gaussian_filter(histogram_array_blanked[i, :], 0.5), # TODO: need to adapt to kinetics of the data
signal2, # gaussian_filter(signal2_blanked, 0.5)
)

percent_diff = np.exp(-(np.abs(1 - np.sum(shift_signal1) / np.sum(histogram_array_blanked[i, :]) ) ) ** 2 / 1.2 ** 2) # ** 6
percent_diff = np.exp(
-((np.abs(1 - np.sum(shift_signal1) / np.sum(histogram_array_blanked[i, :]))) ** 2) / 1.2**2
) # ** 6
corr_value *= percent_diff # heavily penalise interpolation errors

corr_value *= 1 - np.abs(sh - 0) / thr
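
For reference, the weighting applied to each candidate shift above combines a Gaussian falloff in the relative change of the histogram's total mass after interpolation (width 1.2) with a linear taper in the shift magnitude. A standalone sketch of that weighting, with made-up inputs:

# Illustration only: the interpolation-mass penalty and shift taper used above.
import numpy as np

def shift_weight(shifted_sum, original_sum, shift, thr):
    # Gaussian penalty on the relative change in total mass after interpolation (width 1.2)
    mass_term = np.exp(-np.abs(1 - shifted_sum / original_sum) ** 2 / 1.2**2)
    # Linear taper that discounts shifts approaching the threshold
    taper = 1 - np.abs(shift) / thr
    return mass_term * taper

print(shift_weight(shifted_sum=95.0, original_sum=100.0, shift=3, thr=20))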
@@ -979,10 +979,12 @@ def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogra
for j in range(interp_blanked_histograms.shape[0]):
if i == j:
continue
corr_value += _correlate(signal_i_shift,interp_blanked_histograms[j, :])
corr_value += _correlate(signal_i_shift, interp_blanked_histograms[j, :])

percent_diff = np.exp(-(np.abs(1 - np.sum(signal_i_shift) / np.sum(interp_blanked_histograms[i, :]))) ** 2 / 1.2 ** 2) # ** 6
corr_value *= percent_diff # heavily penalise interpolation errors
percent_diff = np.exp(
-((np.abs(1 - np.sum(signal_i_shift) / np.sum(interp_blanked_histograms[i, :]))) ** 2) / 1.2**2
) # ** 6
corr_value *= percent_diff # heavily penalise interpolation errors

# corr_value *= 1 - np.abs(sh - 0) / thr

@@ -1009,6 +1011,7 @@ def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogra

return new_positions


def get_threshold_array(num_bins, windows):
num_points = len(windows)
max = num_bins // 2 # TODO: should probably try both, and take maximum!
@@ -1018,6 +1021,8 @@ def get_threshold_array(num_bins, windows):
x_values = np.arange(num_points)
all_thr = (max - min) * np.exp(-1.2 * x_values) + min
return all_thr


def get_shifts_pairwise(signal1, signal2, windows, plot=True):

import matplotlib.pyplot as plt
@@ -1048,7 +1053,9 @@ def get_shifts_pairwise(signal1, signal2, windows, plot=True):

window_corrs[np.isnan(window_corrs)] = 0
if np.any(window_corrs):
max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...
max_window = np.argmax(
np.abs(window_corrs)
) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...

best_displacements[windows[max_window]] = displacements[windows[max_window]]

@@ -1063,8 +1070,7 @@ def get_shifts_pairwise(signal1, signal2, windows, plot=True):
# TODO: try running loss
# the nonlinear fundamentally doesn't work well for lots of units!
def get_shifts_union(histogram_array, windows, plot=True):
"""
"""
""" """
x_orig = np.arange(histogram_array.shape[1])
new_positions = np.vstack([x_orig.copy()] * histogram_array.shape[0])
fixed_windows = np.zeros_like(histogram_array).astype(bool)
@@ -1077,7 +1083,6 @@ def get_shifts_union(histogram_array, windows, plot=True):

thr = all_thr[round]


if round == 0:
shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1]))

@@ -1088,28 +1093,31 @@ def get_shifts_union(histogram_array, windows, plot=True):
)
new_positions = np.mean(shift_matrix, axis=1)
else:
new_positions = cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array, thr, round, plot=True)

new_positions = cross_correlate_combined_loss(
x_orig, new_positions, fixed_windows, histogram_array, thr, round, plot=True
)

histogram_array_interp = np.zeros_like(histogram_array)
histogram_array_ = histogram_array.copy()
for i in range(histogram_array.shape[0]):
histogram_array_[i, fixed_windows[i, : ]] = 0
histogram_array_[i, fixed_windows[i, :]] = 0
histogram_array_interp[i, :] = _interp(histogram_array_[i, :], new_positions[i, :], x_orig)


window_corrs = np.empty(len(windows)) # okay need to increase but shouldn't fail for one window
for i, idx in enumerate(windows):
window_corrs[i] = np.sum(np.triu(np.cov(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small
window_corrs[i] = np.sum(
np.triu(np.cov(histogram_array_interp[:, idx]), k=1)
) # det doesn't work very well, too small

window_corrs[np.isnan(window_corrs)] = 0
window_corrs[window_corrs < 0] = 0
print(window_corrs)


# Now fix indices and blank in the original space
if np.any(window_corrs):
max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...
max_window = np.argmax(
np.abs(window_corrs)
) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...
clipped_windows.append(max_window)

for i in range(histogram_array.shape[0]):
@@ -1127,7 +1135,11 @@ def get_shifts_union(histogram_array, windows, plot=True):
else:
window_max = windows[mw][-1]

fixed_indices = np.where(np.logical_and(np.ceil(new_positions[i, :]) >= window_min, np.floor(new_positions[i, :]) <= window_max))
fixed_indices = np.where(
np.logical_and(
np.ceil(new_positions[i, :]) >= window_min, np.floor(new_positions[i, :]) <= window_max
)
)
fixed_windows[i, fixed_indices] = True

window = fixed_windows[i, :]
@@ -1147,10 +1159,11 @@ def get_shifts_union(histogram_array, windows, plot=True):

breakpoint()


window_corrs[max_window] = 0

if np.any(window_corrs > 0.001): # still not resolved, this is dependent on cov so changes every time. A running loss will be much better.
if np.any(
window_corrs > 0.001
): # still not resolved, this is dependent on cov so changes every time. A running loss will be much better.
break

return np.ceil(x_orig - new_positions) # or round?
@@ -1192,7 +1205,9 @@ def get_shifts_pairwise(signal1, signal2, windows, plot=True):

window_corrs[np.isnan(window_corrs)] = 0
if np.any(window_corrs):
max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...
max_window = np.argmax(
np.abs(window_corrs)
) # TODO: cutoff! TODO: not sure about the abs, very weird edge case...

best_displacements[windows[max_window]] = displacements[windows[max_window]]

@@ -1219,14 +1234,14 @@ def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plo
if not np.any(nonzero):
continue

midpoint = nonzero[0] + np.ptp(nonzero) / 2
x_scale = (x - midpoint) * scale + midpoint

# interp_f = scipy.interpolate.interp1d(
# x_scale, signal1_blanked, fill_value=0.0, bounds_error=False
# ) # TODO: try cubic etc... or Kriging

# scaled_func = interp_f(x)

for sh in np.arange(-thr, thr): # TODO: we are off by one here

@@ -1241,10 +1256,7 @@ def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plo

from scipy.ndimage import gaussian_filter

corr_value = _correlate(
gaussian_filter(shift_signal1_blanked, 1.5),
gaussian_filter(signal2_blanked, 1.5)
)
corr_value = _correlate(gaussian_filter(shift_signal1_blanked, 1.5), gaussian_filter(signal2_blanked, 1.5))

if np.isnan(corr_value) or corr_value < 0:
corr_value = 0
@@ -1351,14 +1363,16 @@ def _compute_session_alignment(

nonrigid_session_offsets_matrix[i, j, :] = shifts1

# TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S

# breakpoint()
# Then compute the nonrigid shifts
# nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
# shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
# )
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, -nonrigid_session_offsets_matrix) # nonrigid_session_offsets_matrix[0, :, :] #
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(
alignment_order, -nonrigid_session_offsets_matrix
) # nonrigid_session_offsets_matrix[0, :, :] #

non_rigid_window_centers = spatial_bin_centers
shifts = non_rigid_shifts
