Skip to content

Commit

Permalink
Add a test of time bin changing at interpolatino time
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed May 31, 2024
1 parent d99d05b commit d7b6a59
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
10 changes: 7 additions & 3 deletions src/spikeinterface/sortingcomponents/motion_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing import get_spatial_interpolation_kernel
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.preprocessing.basepreprocessor import (
BasePreprocessor, BasePreprocessorSegment)

from ..preprocessing.filter import fix_dtype

Expand Down Expand Up @@ -285,7 +286,7 @@ class InterpolateMotionRecording(BasePreprocessor):
Recording after motion correction
"""

name = "correct_motion"
name = "interpolate_motion"

def __init__(
self,
Expand All @@ -299,6 +300,7 @@ def __init__(
interpolation_time_bin_centers_s=None,
interpolation_time_bin_size_s=None,
dtype=None,
**spatial_interpolation_kwargs,
):
# assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings"

Expand All @@ -307,7 +309,9 @@ def __init__(
f"'direction' {motion.direction} not available. "
f"Channel locations have {channel_locations.ndim} dimensions."
)
spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest)
spatial_interpolation_kwargs = dict(
sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs
)
if border_mode == "remove_channels":
locs = channel_locations[:, motion.dim]
l0, l1 = np.min(locs), np.max(locs)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import shutil
from pathlib import Path

import numpy as np
import pytest
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass
from spikeinterface.sortingcomponents.peak_localization import \
LocalizeCenterOfMass
from spikeinterface.sortingcomponents.tests.common import make_dataset

if hasattr(pytest, "global_test_folder"):
Expand Down Expand Up @@ -153,7 +152,6 @@ def test_estimate_motion():
)
kwargs.update(cases_kwargs)

job_kwargs = dict(progress_bar=False)
motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs)
motions[name] = motion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import spikeinterface.core as sc
from spikeinterface import download_dataset
from spikeinterface.sortingcomponents.motion_interpolation import (
InterpolateMotionRecording,
correct_motion_on_peaks,
interpolate_motion,
interpolate_motion_on_traces,
)
InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion,
interpolate_motion_on_traces)
from spikeinterface.sortingcomponents.motion_utils import Motion
from spikeinterface.sortingcomponents.tests.common import make_dataset

Expand Down Expand Up @@ -103,6 +100,22 @@ def test_interpolation_simple():
assert np.array_equal(traces_corrected[:, 0], np.ones(nt))
assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1)))

# let's try a new version where we interpolate too slowly
rec_corrected = interpolate_motion(
rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2
)
traces_corrected = rec_corrected.get_traces()
assert traces_corrected.shape == (nc0, nc0)
# what happens with nearest here?
# well... due to rounding towards the nearest even number, the motion (which at
# these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest
# neighbor back and forth between the first and second channels
assert np.all(traces_corrected[::2, 0] == 1)
assert np.all(traces_corrected[1::2, 0] == 0)
assert np.all(traces_corrected[1::2, 1] == 1)
assert np.all(traces_corrected[::2, 1] == 0)
assert np.all(traces_corrected[:, 2:] == 0)


def test_InterpolateMotionRecording():
rec, sorting = make_dataset()
Expand Down

0 comments on commit d7b6a59

Please sign in to comment.