From e6247562a854d2619d7380acecb0d9be496c504f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 31 May 2024 07:08:37 -0600 Subject: [PATCH] fix testing imports --- .../tests/test_nwbextractors_streaming.py | 70 ------------------- .../tests/common_extension_tests.py | 2 - .../waveforms/savgol_denoiser.py | 6 +- .../widgets/tests/test_widgets.py | 5 +- 4 files changed, 6 insertions(+), 77 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py index 2732e5077a..b3c5b9c934 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py @@ -1,10 +1,8 @@ from pathlib import Path import pickle -from tabnanny import check import pytest import numpy as np -import h5py from spikeinterface import load_extractor from spikeinterface.core.testing import check_recordings_equal @@ -12,43 +10,6 @@ from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor -@pytest.mark.streaming_extractors -@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_recording_s3_nwb_ros3(tmp_path): - file_path = ( - "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" - ) - rec = NwbRecordingExtractor(file_path, stream_mode="ros3") - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = rec.get_num_segments() - num_chans = rec.get_num_channels() - dtype = rec.get_dtype() - - for segment_index in range(num_seg): - num_samples = rec.get_num_samples(segment_index=segment_index) - - full_traces = rec.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) - assert full_traces.shape == (num_frames, num_chans) - assert full_traces.dtype == dtype - - if rec.has_scaleable_traces(): - trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) - assert trace_scaled.dtype == "float32" - - tmp_file = tmp_path / "test_ros3_recording.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(rec, f) - - with open(tmp_file, "rb") as f: - reloaded_recording = pickle.load(f) - - check_recordings_equal(rec, reloaded_recording) - - @pytest.mark.streaming_extractors @pytest.mark.parametrize("cache", [True, False]) # Test with and without cache def test_recording_s3_nwb_fsspec(tmp_path, cache): @@ -154,37 +115,6 @@ def test_recording_s3_nwb_remfile_file_like(tmp_path): check_recordings_equal(rec, rec2) -@pytest.mark.streaming_extractors -@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") -def test_sorting_s3_nwb_ros3(tmp_path): - file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" - # we provide the 'sampling_frequency' because the NWB file does not the electrical series - sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3", t_start=0) - - start_frame = 0 - end_frame = 300 - num_frames = end_frame - start_frame - - num_seg = sort.get_num_segments() - num_units = len(sort.unit_ids) - - for segment_index in range(num_seg): - for unit in sort.unit_ids: - spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) - assert len(spike_train) > 0 - assert spike_train.dtype == "int64" - assert np.all(spike_train >= 0) - - tmp_file = tmp_path / "test_ros3_sorting.pkl" - with open(tmp_file, "wb") as f: - pickle.dump(sort, f) - - with open(tmp_file, "rb") as f: - reloaded_sorting = pickle.load(f) - - check_sortings_equal(reloaded_sorting, sort) - - @pytest.mark.streaming_extractors @pytest.mark.parametrize("cache", [True, False]) # Test with and without cache def test_sorting_s3_nwb_fsspec(tmp_path, cache): diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index bf462a9466..605997f5f6 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -2,9 +2,7 @@ import pytest import numpy as np -import pandas as pd import shutil -import platform from pathlib import Path from spikeinterface.core import generate_ground_truth_recording diff --git a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py index e03d52fb35..2a54fe231c 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py @@ -1,9 +1,7 @@ from __future__ import annotations -from pathlib import Path -import json + from typing import List, Optional -import scipy.signal from spikeinterface.core import BaseRecording from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type @@ -56,6 +54,8 @@ def __init__( def compute(self, traces, peaks, waveforms): # Denoise + import scipy.signal + denoised_waveforms = scipy.signal.savgol_filter(waveforms, self.window_length, self.order, axis=1) return denoised_waveforms diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 5366fb864f..775d0b3fc5 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -8,8 +8,6 @@ matplotlib.use("Agg") -import matplotlib.pyplot as plt - from spikeinterface import ( compute_sparsity, @@ -578,12 +576,15 @@ def test_plot_multicomparison(self): for backend in possible_backends_by_sorter: sw.plot_multicomparison_agreement_by_sorter(mcmp) if backend == "matplotlib": + import matplotlib.pyplot as plt + _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) if __name__ == "__main__": # unittest.main() + import matplotlib.pyplot as plt TestWidgets.setUpClass() mytest = TestWidgets()