Skip to content

Commit

Permalink
Start adding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jan 14, 2025
1 parent 0b829d9 commit 1bf0487
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 18 deletions.
4 changes: 3 additions & 1 deletion debugging/playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
# Load / generate some recordings
# --------------------------------------------------------------------------------------

# try num units 5 and 65

recordings_list, _ = generate_session_displacement_recordings(
num_units=65,
num_units=5,
recording_durations=[200, 200, 200],
recording_shifts=((0, 0), (0, -200), (0, 150)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
non_rigid_gradient=None, # 0.1, # 0.1,
Expand Down
10 changes: 10 additions & 0 deletions playing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from spikeinterface.generation import generate_drifting_recording
from spikeinterface.preprocessing.motion import correct_motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording

rec = generate_drifting_recording(duration=100)[0]

proc_rec = correct_motion(rec)

rec.set_probe(rec.get_probe())

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_activity_histogram(
depth_smooth_um: float | None,
scale_to_hz: bool = False,
weight_with_amplitude: bool = False,
avg_in_bin: bool = True,
):
"""
Generate a 2D activity histogram for the session. Wraps the underlying
Expand Down Expand Up @@ -69,6 +70,7 @@ def get_activity_histogram(
hist_margin_um=None,
spatial_bin_edges=spatial_bin_edges,
depth_smooth_um=depth_smooth_um,
avg_in_bin=avg_in_bin,
)
assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing"

Expand All @@ -88,7 +90,6 @@ def get_activity_histogram(

return activity_histogram, temporal_bin_centers, spatial_bin_centers


def get_bin_centers(bin_edges):
return (bin_edges[1:] + bin_edges[:-1]) / 2

Expand Down Expand Up @@ -310,6 +311,14 @@ def compute_histogram_crosscorrelation(
windowed_histogram_j - np.mean(windowed_histogram_i),
mode="full",
)
import os
if "hello_world" in os.environ:
plt.plot(windowed_histogram_i)
plt.plot(windowed_histogram_j)
plt.show()

plt.plot(xcorr)
plt.show()

if num_shifts:
window_indices = np.arange(center_bin - num_shifts, center_bin + num_shifts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
from spikeinterface.preprocessing.inter_session_alignment import alignment_utils
from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node
import copy
import scipy
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt


def get_estimate_histogram_kwargs() -> dict:
Expand Down Expand Up @@ -54,7 +50,8 @@ def get_estimate_histogram_kwargs() -> dict:
"log_scale": False,
"depth_smooth_um": None,
"histogram_type": "activity_1d",
"weight_with_amplitude": True,
"weight_with_amplitude": False,
"avg_in_bin": False, # TODO
}


Expand Down Expand Up @@ -111,8 +108,12 @@ def get_interpolate_motion_kwargs():
Settings to pass to `InterpolateMotionRecording`,
see that class for parameter descriptions.
"""
return {"border_mode": "remove_channels", "spatial_interpolation_method": "kriging", "sigma_um": 20.0, "p": 2}

return {
"border_mode": "force_zeros", # fixed as this until can figure out probe
"spatial_interpolation_method": "kriging",
"sigma_um": 20.0,
"p": 2
}

###############################################################################
# Public Entry Level Functions
Expand Down Expand Up @@ -221,7 +222,7 @@ def align_sessions(

# Ensure list lengths match and all channel locations are the same across recordings.
_check_align_sessions_inputs(
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs, interpolate_motion_kwargs
)

print("Computing a single activity histogram from each session...")
Expand Down Expand Up @@ -400,6 +401,7 @@ def _compute_session_histograms(
depth_smooth_um: float,
log_scale: bool,
weight_with_amplitude: bool,
avg_in_bin: bool,
) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]:
"""
Compute a 1d activity histogram for the session. As
Expand Down Expand Up @@ -464,6 +466,7 @@ def _compute_session_histograms(
chunked_bin_size_s,
depth_smooth_um,
weight_with_amplitude,
avg_in_bin,
)
temporal_bin_centers_list.append(temporal_bin_centers)
session_histogram_list.append(session_hist)
Expand All @@ -489,6 +492,7 @@ def _get_single_session_activity_histogram(
chunked_bin_size_s: float | "estimate",
depth_smooth_um: float,
weight_with_amplitude: bool,
avg_in_bin: bool,
) -> tuple[np.ndarray, np.ndarray, dict]:
"""
Compute an activity histogram for a single session.
Expand Down Expand Up @@ -544,11 +548,12 @@ def _get_single_session_activity_histogram(
bin_s=None,
depth_smooth_um=None,
scale_to_hz=False,
weight_with_amplitude=weight_with_amplitude,
weight_with_amplitude=False,
avg_in_bin=False,
)

# It is important that the passed histogram is scaled to firing rate in Hz
scaled_hist = one_bin_histogram / recording.get_duration()
scaled_hist = one_bin_histogram / recording.get_duration() # TODO: why is this done here when have a scale_to_hz arg??!?
chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist)
chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()])

Expand All @@ -563,6 +568,7 @@ def _get_single_session_activity_histogram(
bin_s=chunked_bin_size_s,
depth_smooth_um=depth_smooth_um,
weight_with_amplitude=weight_with_amplitude,
avg_in_bin=avg_in_bin,
scale_to_hz=True,
)

Expand Down Expand Up @@ -645,7 +651,14 @@ def _create_motion_recordings(

corrected_recording = _add_displacement_to_interpolate_recording(recording, motion)
else:
corrected_recording = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
corrected_recording = InterpolateMotionRecording(
recording,
motion,
interpolation_time_bin_centers_s=motion.temporal_bins_s,
interpolation_time_bin_edges_s=[np.array(recording.get_times()[0], recording.get_times()[-1])],
**interpolate_motion_kwargs
)
corrected_recording = corrected_recording.set_probe(recording.get_probe()) # TODO: if this works, might need to do above

corrected_recordings_list.append(corrected_recording)

Expand Down Expand Up @@ -780,6 +793,7 @@ def _correct_session_displacement(
estimate_histogram_kwargs["chunked_bin_size_s"],
estimate_histogram_kwargs["depth_smooth_um"],
estimate_histogram_kwargs["weight_with_amplitude"],
estimate_histogram_kwargs["avg_in_bin"],
)
corrected_session_histogram_list.append(session_hist)

Expand Down Expand Up @@ -927,6 +941,7 @@ def _check_align_sessions_inputs(
peak_locations_list: list[np.ndarray],
alignment_order: str,
estimate_histogram_kwargs: dict,
interpolate_motion_kwargs: dict,
):
"""
Perform checks on the input of `align_sessions()`
Expand All @@ -946,13 +961,14 @@ def _check_align_sessions_inputs(
)

channel_locs = [rec.get_channel_locations() for rec in recordings_list]
if not all(np.array_equal(locs, channel_locs[0]) for locs in channel_locs):
if not all([np.array_equal(locs, channel_locs[0]) for locs in channel_locs]):
raise ValueError(
"The recordings in `recordings_list` do not all have "
"the same channel locations. All recordings must be "
"performed using the same probe."
)


accepted_hist_methods = [
"entire_session",
"chunked_mean",
Expand Down Expand Up @@ -981,3 +997,5 @@ def _check_align_sessions_inputs(

if ses_num == 0:
raise ValueError("`alignment_order` required the session number, not session index.")

assert interpolate_motion_kwargs["border_mode"] == "force_zeros", "InterpolateMotionRecording must be `force_zeros` until probe is figured out." # TODO: ask sam
Loading

0 comments on commit 1bf0487

Please sign in to comment.