Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 31, 2024
1 parent 5df55f7 commit 221afde
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 23 deletions.
6 changes: 2 additions & 4 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,7 @@ def correct_motion(
t1 = time.perf_counter()
run_times["estimate_motion"] = t1 - t0

recording_corrected = InterpolateMotionRecording(
recording, motion, **interpolate_motion_kwargs
)
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)

if folder is not None:
(folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
Expand Down Expand Up @@ -434,7 +432,7 @@ def load_motion_info(folder):
motion_info[name] = np.load(folder / f"{name}.npy")
else:
motion_info[name] = None

motion_info["motion"] = Motion.load(folder / "motion")

return motion_info
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/tests/test_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_estimate_and_correct_motion():
folder = cache_folder / "estimate_and_correct_motion"
if folder.exists():
shutil.rmtree(folder)

rec_corrected = correct_motion(rec, folder=folder)
print(rec_corrected)

Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/sortingcomponents/motion_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import numpy as np
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing import get_spatial_interpolation_kernel
from spikeinterface.preprocessing.basepreprocessor import (
BasePreprocessor, BasePreprocessorSegment)
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment

from ..preprocessing.filter import fix_dtype

Expand Down Expand Up @@ -346,8 +345,7 @@ def __init__(
if recording.dtype.kind == "f":
dtype = recording.dtype
else:
raise ValueError(
f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.")
raise ValueError(f"Can't interpolate traces of recording with non-floating dtype={recording.dtype=}.")

dtype_ = fix_dtype(recording, dtype)
BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_)
Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/sortingcomponents/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from spikeinterface.core import generate_ground_truth_recording



def make_dataset():
# this replace the MEArec 10s file for testing
recording, sorting = generate_ground_truth_recording(
Expand All @@ -23,4 +22,3 @@ def make_dataset():
seed=2205,
)
return recording, sorting

Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_estimate_motion():

# same params with differents engine should be the same
motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"]
assert (motion0 == motion1)
assert motion0 == motion1

motion0, motion1 = (
motions["rigid / decentralized / torch / time_horizon_s"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import spikeinterface.core as sc
from spikeinterface import download_dataset
from spikeinterface.sortingcomponents.motion_interpolation import (
InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion,
interpolate_motion_on_traces)
InterpolateMotionRecording,
correct_motion_on_peaks,
interpolate_motion,
interpolate_motion_on_traces,
)
from spikeinterface.sortingcomponents.motion_utils import Motion
from spikeinterface.sortingcomponents.tests.common import make_dataset

Expand Down
16 changes: 7 additions & 9 deletions src/spikeinterface/sortingcomponents/tests/test_motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@

def test_Motion():

temporal_bins_s = np.arange(0., 10., 1.)
spatial_bins_um = np.array([100., 200.])
temporal_bins_s = np.arange(0.0, 10.0, 1.0)
spatial_bins_um = np.array([100.0, 200.0])

displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0]))
displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis]

motion = Motion(
displacement, temporal_bins_s, spatial_bins_um, direction="y"
)
motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")
print(motion)

# serialize with pickle before interpolation fit
Expand All @@ -40,16 +38,16 @@ def test_Motion():
assert motion2.interpolator is None

# do interpolate
displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120., 80., 150.])
displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0])
# print(displacement)
assert displacement.shape[0] == 3
# check clip
assert displacement[2] == 20.
assert displacement[2] == 20.0

# interpolate grid
displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150., 80.], grid=True)
displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150.0, 80.0], grid=True)
assert displacement.shape == (2, 5)
assert displacement[0, 2] == 20.
assert displacement[0, 2] == 20.0

# save/load to folder
folder = cache_folder / "motion_saved"
Expand Down

0 comments on commit 221afde

Please sign in to comment.