Skip to content

Commit

Permalink
Continue adding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jan 16, 2025
1 parent 070e33f commit 9da4cbf
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 136 deletions.
69 changes: 66 additions & 3 deletions playing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,73 @@
from spikeinterface.generation import generate_drifting_recording
from spikeinterface.preprocessing.motion import correct_motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
from spikeinterface.generation import generate_ground_truth_recording
from spikeinterface.core import get_noise_levels
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

rec = generate_drifting_recording(duration=100)[0]

proc_rec = correct_motion(rec)
recordings_list, _ = generate_session_displacement_recordings(
num_units=5,
recording_durations=[1, 1],
recording_shifts=((0, 0), (0, 250)),
# TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
non_rigid_gradient=None, # 0.1, # 0.1,
seed=55, # 52
generate_sorting_kwargs=dict(firing_rates=(100, 250), refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=0.0,
# if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
minimum_z=0.0,
maximum_z=2.0,
minimum_distance=18.0,
max_iteration=100,
distance_strict=False,
),
generate_noise_kwargs=dict(noise_levels=(0.0, 1.0), spatial_decay=1.0),
)
rec = recordings_list[1]

rec.set_probe(rec.get_probe())
detect_kwargs = {
"method": "locally_exclusive",
"peak_sign": "neg",
"detect_threshold": 25,
"exclude_sweep_ms": 0.1,
"radius_um": 75,
"noise_levels": None,
"random_chunk_kwargs": {},
}
localize_peaks_kwargs = {"method": "grid_convolution"}

# noise_levels = get_noise_levels(rec, return_scaled=False)
rec_0 = recordings_list[0]
rec_1 = recordings_list[1]

peaks_before_0 = detect_peaks(rec_0, **detect_kwargs) # noise_levels=noise_levels,
peaks_before_1 = detect_peaks(rec_1, **detect_kwargs)

proc_rec_0, motion_info_0 = correct_motion(rec_0, preset="rigid_fast", detect_kwargs=detect_kwargs, localize_peaks_kwargs=localize_peaks_kwargs, output_motion_info=True)
proc_rec_1, motion_info_1 = correct_motion(rec_1, preset="rigid_fast", detect_kwargs=detect_kwargs, localize_peaks_kwargs=localize_peaks_kwargs, output_motion_info=True)

peaks_after_0 = detect_peaks(proc_rec_0, **detect_kwargs) # noise_levels=noise_levels
peaks_after_1 = detect_peaks(proc_rec_1, **detect_kwargs)


import spikeinterface.full as si
import matplotlib.pyplot as plt

# TODO: need to test multi-shank
plot = si.plot_traces(rec_1, order_channel_by_depth=True) # , time_range=(0, 0.1))
x = peaks_before_1["sample_index"] * (1/ rec_1.get_sampling_frequency())
y = rec_1.get_channel_locations()[peaks_before_1["channel_index"], 1]
plot.ax.scatter(x, y, color="r", s=2)
plt.show()

plot = si.plot_traces(proc_rec_1, order_channel_by_depth=True)
x = peaks_after_1["sample_index"] * (1/ proc_rec_1.get_sampling_frequency())
y = rec_1.get_channel_locations()[peaks_after_1["channel_index"], 1]
plot.ax.scatter(x, y, color="r", s=2)
plt.show()

breakpoint()
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# #############################################################################


def get_activity_histogram(
def get_2d_activity_histogram(
recording: BaseRecording,
peaks: np.ndarray,
peak_locations: np.ndarray,
Expand Down Expand Up @@ -74,9 +74,6 @@ def get_activity_histogram(
)
assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing"

temporal_bin_centers = get_bin_centers(temporal_bin_edges)
spatial_bin_centers = get_bin_centers(spatial_bin_edges)

if scale_to_hz:
if bin_s is None:
scaler = 1 / recording.get_duration()
Expand All @@ -88,8 +85,12 @@ def get_activity_histogram(
if log_scale:
activity_histogram = np.log10(1 + activity_histogram) # TODO: make_2d_motion_histogram uses log2

temporal_bin_centers = get_bin_centers(temporal_bin_edges)
spatial_bin_centers = get_bin_centers(spatial_bin_edges)

return activity_histogram, temporal_bin_centers, spatial_bin_centers


def get_bin_centers(bin_edges):
return (bin_edges[1:] + bin_edges[:-1]) / 2

Expand Down Expand Up @@ -152,9 +153,6 @@ def get_chunked_hist_median(chunked_session_histograms):
""" """
median_hist = np.median(chunked_session_histograms, axis=0)

quartile_1 = np.percentile(chunked_session_histograms, 25, axis=0)
quartile_3 = np.percentile(chunked_session_histograms, 75, axis=0)

return median_hist


Expand Down Expand Up @@ -311,15 +309,6 @@ def compute_histogram_crosscorrelation(
windowed_histogram_j - np.mean(windowed_histogram_i),
mode="full",
)
import os
if "hello_world" in os.environ:
plt.plot(windowed_histogram_i)
plt.plot(windowed_histogram_j)
plt.show()

plt.plot(xcorr)
plt.show()

if num_shifts:
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)
xcorr = xcorr[window_indices]
Expand Down Expand Up @@ -436,7 +425,9 @@ def akima_interpolate_nonrigid_shifts(
interpolated from the non-rigid shifts.
"""
if Version(scipy.__version__) >= Version("1.4.0"):
import scipy

if Version(scipy.__version__) < Version("1.14.0"):
raise ImportError("Scipy version 14 or higher is required fro Akima interpolation.")

from scipy.interpolate import Akima1DInterpolator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_compute_alignment_kwargs() -> dict:
windows along the probe depth. See `get_spatial_windows`.
"""
return {
"num_shifts_global": None,
"num_shifts_block": 20,
"interpolate": False,
"interp_factor": 10,
"kriging_sigma": 1,
Expand All @@ -93,8 +95,6 @@ def get_non_rigid_window_kwargs():
"""
return {
"rigid": True,
"num_shifts_global": None,
"num_shifts_block": 20,
"win_shape": "gaussian",
"win_step_um": 50,
"win_scale_um": 50,
Expand All @@ -109,12 +109,13 @@ def get_interpolate_motion_kwargs():
see that class for parameter descriptions.
"""
return {
"border_mode": "force_zeros", # fixed as this until can figure out probe
"border_mode": "force_zeros", # fixed as this until can figure out probe
"spatial_interpolation_method": "kriging",
"sigma_um": 20.0,
"p": 2
"p": 2,
}


###############################################################################
# Public Entry Level Functions
###############################################################################
Expand Down Expand Up @@ -222,7 +223,12 @@ def align_sessions(

# Ensure list lengths match and all channel locations are the same across recordings.
_check_align_sessions_inputs(
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs, interpolate_motion_kwargs
recordings_list,
peaks_list,
peak_locations_list,
alignment_order,
estimate_histogram_kwargs,
interpolate_motion_kwargs,
)

print("Computing a single activity histogram from each session...")
Expand Down Expand Up @@ -311,7 +317,10 @@ def align_sessions_after_motion_correction(
)

motion_window_kwargs = copy.deepcopy(motion_kwargs_list[0])
if motion_window_kwargs["direction"] != "y":

if (
"direction" in motion_window_kwargs and motion_window_kwargs["direction"] != "y"
): # TODO: why is this not in all?
raise ValueError("motion correct must have been performed along the 'y' dimension.")

if align_sessions_kwargs is None:
Expand All @@ -322,24 +331,37 @@ def align_sessions_after_motion_correction(
# shifts together.
if (
"non_rigid_window_kwargs" in align_sessions_kwargs
and "nonrigid" in align_sessions_kwargs["non_rigid_window_kwargs"]["rigid_mode"]
and not align_sessions_kwargs["non_rigid_window_kwargs"]["rigid"]
):

# TODO: carefully walk through this function! and test all assumptions...
if not motion_window_kwargs["rigid"]:
print(
print( # TODO: make a warning
"Nonrigid inter-session alignment must use the motion correct "
"nonrigid settings. Overwriting any passed `non_rigid_window_kwargs` "
"with the motion object non_rigid_window_kwargs."
)
motion_window_kwargs.pop("method")
motion_window_kwargs.pop("direction")
non_rigid_window_kwargs = get_non_rigid_window_kwargs()

# TODO: generate function for replacing one dict into another?
for (
k,
v,
) in motion_window_kwargs.items(): # TODO: can get tighter alignment here with original implementation?
if k in non_rigid_window_kwargs:
non_rigid_window_kwargs[k] = v

align_sessions_kwargs = copy.deepcopy(align_sessions_kwargs)
align_sessions_kwargs["non_rigid_window_kwargs"] = motion_window_kwargs
align_sessions_kwargs["non_rigid_window_kwargs"] = non_rigid_window_kwargs

corrected_peak_locations = [
correct_motion_on_peaks(info["peaks"], info["peak_locations"], info["motion"], recording)
for info, recording in zip(motion_info_list, recordings_list)
]

return align_sessions(
recordings_list,
[info["peaks"] for info in motion_info_list],
[info["peak_locations"] for info in motion_info_list],
corrected_peak_locations,
**align_sessions_kwargs,
)

Expand Down Expand Up @@ -459,14 +481,14 @@ def _compute_session_histograms(
recording,
peaks,
peak_locations,
histogram_type,
spatial_bin_edges,
method,
log_scale,
chunked_bin_size_s,
depth_smooth_um,
weight_with_amplitude,
avg_in_bin,
histogram_type=histogram_type,
spatial_bin_edges=spatial_bin_edges,
method=method,
log_scale=log_scale,
chunked_bin_size_s=chunked_bin_size_s,
depth_smooth_um=depth_smooth_um,
weight_with_amplitude=weight_with_amplitude,
avg_in_bin=avg_in_bin,
)
temporal_bin_centers_list.append(temporal_bin_centers)
session_histogram_list.append(session_hist)
Expand Down Expand Up @@ -539,32 +561,31 @@ def _get_single_session_activity_histogram(
# full estimation for chunked bin size
if chunked_bin_size_s == "estimate":

one_bin_histogram, _, _ = alignment_utils.get_activity_histogram(
scaled_hist, _, _ = alignment_utils.get_2d_activity_histogram(
recording,
peaks,
peak_locations,
spatial_bin_edges,
log_scale=False,
bin_s=None,
depth_smooth_um=None,
scale_to_hz=False,
scale_to_hz=True,
weight_with_amplitude=False,
avg_in_bin=False,
)

# It is important that the passed histogram is scaled to firing rate in Hz
scaled_hist = one_bin_histogram / recording.get_duration() # TODO: why is this done here when have a scale_to_hz arg??!?
chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist)
chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()])

if histogram_type == "activity_1d":

chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram(
chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_2d_activity_histogram(
recording,
peaks,
peak_locations,
spatial_bin_edges,
log_scale,
log_scale=log_scale,
bin_s=chunked_bin_size_s,
depth_smooth_um=depth_smooth_um,
weight_with_amplitude=weight_with_amplitude,
Expand Down Expand Up @@ -656,9 +677,11 @@ def _create_motion_recordings(
motion,
interpolation_time_bin_centers_s=motion.temporal_bins_s,
interpolation_time_bin_edges_s=[np.array(recording.get_times()[0], recording.get_times()[-1])],
**interpolate_motion_kwargs
**interpolate_motion_kwargs,
)
corrected_recording = corrected_recording.set_probe(recording.get_probe()) # TODO: if this works, might need to do above
corrected_recording = corrected_recording.set_probe(
recording.get_probe()
) # TODO: if this works, might need to do above

corrected_recordings_list.append(corrected_recording)

Expand Down Expand Up @@ -840,8 +863,8 @@ def _compute_session_alignment(
session_histogram_array = np.array(session_histogram_list)

akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid")
num_shifts_global = non_rigid_window_kwargs.pop("num_shifts_global")
num_shifts_block = non_rigid_window_kwargs.pop("num_shifts_block")
num_shifts_global = compute_alignment_kwargs.pop("num_shifts_global")
num_shifts_block = compute_alignment_kwargs.pop("num_shifts_block")

non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
contact_depths, spatial_bin_centers, **non_rigid_window_kwargs
Expand Down Expand Up @@ -870,7 +893,7 @@ def _compute_session_alignment(

# Then compute the nonrigid shifts
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
shifted_histograms, non_rigid_windows, num_shifts_block, **compute_alignment_kwargs
shifted_histograms, non_rigid_windows, num_shifts=num_shifts_block, **compute_alignment_kwargs
)
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)

Expand Down Expand Up @@ -920,7 +943,7 @@ def _estimate_rigid_alignment(
rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
session_histogram_array,
rigid_window,
num_shifts,
num_shifts=num_shifts,
**compute_alignment_kwargs, # TODO: remove the copy above and pass directly. Consider removing this function...
)
optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix(
Expand Down Expand Up @@ -968,7 +991,6 @@ def _check_align_sessions_inputs(
"performed using the same probe."
)


accepted_hist_methods = [
"entire_session",
"chunked_mean",
Expand Down Expand Up @@ -998,4 +1020,6 @@ def _check_align_sessions_inputs(
if ses_num == 0:
raise ValueError("`alignment_order` required the session number, not session index.")

assert interpolate_motion_kwargs["border_mode"] == "force_zeros", "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
assert (
interpolate_motion_kwargs["border_mode"] == "force_zeros"
), "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
4 changes: 4 additions & 0 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ def run_peak_detection_pipeline_node(recording, gather_mode, detect_kwargs, loca
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
from spikeinterface.sortingcomponents.peak_localization import localize_peak_methods

# Don't modify the kwargs in place in case the caller requires them
detect_kwargs = copy.deepcopy(detect_kwargs)
localize_peaks_kwargs = copy.deepcopy(localize_peaks_kwargs)

# node detect
method = detect_kwargs.pop("method", "locally_exclusive")
method_class = detect_peak_methods[method]
Expand Down
Loading

0 comments on commit 9da4cbf

Please sign in to comment.