diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index d0bbbddd71..889e89446d 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -3,7 +3,8 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.basepreprocessor import ( + BasePreprocessor, BasePreprocessorSegment) from ..preprocessing.filter import fix_dtype @@ -285,7 +286,7 @@ class InterpolateMotionRecording(BasePreprocessor): Recording after motion correction """ - name = "correct_motion" + name = "interpolate_motion" def __init__( self, @@ -299,6 +300,7 @@ def __init__( interpolation_time_bin_centers_s=None, interpolation_time_bin_size_s=None, dtype=None, + **spatial_interpolation_kwargs, ): # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" @@ -307,7 +309,9 @@ def __init__( f"'direction' {motion.direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." ) - spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) + spatial_interpolation_kwargs = dict( + sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs + ) if border_mode == "remove_channels": locs = channel_locations[:, motion.dim] l0, l1 = np.min(locs), np.max(locs) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index d916102376..88908c5cc4 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,13 +1,12 @@ -import shutil from pathlib import Path import numpy as np import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass +from spikeinterface.sortingcomponents.peak_localization import \ + LocalizeCenterOfMass from spikeinterface.sortingcomponents.tests.common import make_dataset if hasattr(pytest, "global_test_folder"): @@ -153,7 +152,6 @@ def test_estimate_motion(): ) kwargs.update(cases_kwargs) - job_kwargs = dict(progress_bar=False) motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index 1de0337ec0..ffed3e72fc 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -5,11 +5,8 @@ import spikeinterface.core as sc from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( - InterpolateMotionRecording, - correct_motion_on_peaks, - interpolate_motion, - interpolate_motion_on_traces, -) + InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, + interpolate_motion_on_traces) from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -103,6 +100,22 @@ def test_interpolation_simple(): assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + # let's try a new version where we interpolate too slowly + rec_corrected = interpolate_motion( + rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 + ) + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + # what happens with nearest here? + # well... due to rounding towards the nearest even number, the motion (which at + # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest + # neighbor back and forth between the first and second channels + assert np.all(traces_corrected[::2, 0] == 1) + assert np.all(traces_corrected[1::2, 0] == 0) + assert np.all(traces_corrected[1::2, 1] == 1) + assert np.all(traces_corrected[::2, 1] == 0) + assert np.all(traces_corrected[:, 2:] == 0) + def test_InterpolateMotionRecording(): rec, sorting = make_dataset()