From 221afde68b3f22587a5483e7b642804c8f7599d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 May 2024 14:45:13 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/motion.py | 6 ++---- .../preprocessing/tests/test_motion.py | 2 +- .../sortingcomponents/motion_interpolation.py | 6 ++---- .../sortingcomponents/tests/common.py | 2 -- .../tests/test_motion_estimation.py | 2 +- .../tests/test_motion_interpolation.py | 7 +++++-- .../sortingcomponents/tests/test_motion_utils.py | 16 +++++++--------- 7 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index a5300ccadc..8b89e2f545 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -384,9 +384,7 @@ def correct_motion( t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 - recording_corrected = InterpolateMotionRecording( - recording, motion, **interpolate_motion_kwargs - ) + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") @@ -434,7 +432,7 @@ def load_motion_info(folder): motion_info[name] = np.load(folder / f"{name}.npy") else: motion_info[name] = None - + motion_info["motion"] = Motion.load(folder / "motion") return motion_info diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index d678b2d565..f42a64b90b 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -25,7 +25,7 @@ def test_estimate_and_correct_motion(): folder = cache_folder / "estimate_and_correct_motion" if folder.exists(): shutil.rmtree(folder) - + rec_corrected = correct_motion(rec, folder=folder) print(rec_corrected) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index cbc24c83c3..d0bbbddd71 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -3,8 +3,7 @@ import numpy as np from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.preprocessing import get_spatial_interpolation_kernel -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from ..preprocessing.filter import fix_dtype @@ -346,8 +345,7 @@ def __init__( if recording.dtype.kind == "f": dtype = recording.dtype else: - raise ValueError( - f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") + raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.") dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index 84d532d3aa..01e4445a13 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -3,7 +3,6 @@ from spikeinterface.core import generate_ground_truth_recording - def make_dataset(): # this replace the MEArec 10s file for testing recording, sorting = generate_ground_truth_recording( @@ -23,4 +22,3 @@ def make_dataset(): seed=2205, ) return recording, sorting - diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index e842d876a2..945aa6a09e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -200,7 +200,7 @@ def test_estimate_motion(): # same params with differents engine should be the same motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] - assert (motion0 == motion1) + assert motion0 == motion1 motion0, motion1 = ( motions["rigid / decentralized / torch / time_horizon_s"], diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index b97040a740..1de0337ec0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -5,8 +5,11 @@ import spikeinterface.core as sc from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( - InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, - interpolate_motion_on_traces) + InterpolateMotionRecording, + correct_motion_on_peaks, + interpolate_motion, + interpolate_motion_on_traces, +) from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py index a170245d7d..8a62ef324b 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -14,15 +14,13 @@ def test_Motion(): - temporal_bins_s = np.arange(0., 10., 1.) - spatial_bins_um = np.array([100., 200.]) + temporal_bins_s = np.arange(0.0, 10.0, 1.0) + spatial_bins_um = np.array([100.0, 200.0]) displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0])) displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis] - motion = Motion( - displacement, temporal_bins_s, spatial_bins_um, direction="y" - ) + motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y") print(motion) # serialize with pickle before interpolation fit @@ -40,16 +38,16 @@ def test_Motion(): assert motion2.interpolator is None # do interpolate - displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120., 80., 150.]) + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0]) # print(displacement) assert displacement.shape[0] == 3 # check clip - assert displacement[2] == 20. + assert displacement[2] == 20.0 # interpolate grid - displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150., 80.], grid=True) + displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150.0, 80.0], grid=True) assert displacement.shape == (2, 5) - assert displacement[0, 2] == 20. + assert displacement[0, 2] == 20.0 # save/load to folder folder = cache_folder / "motion_saved"