Skip to content

Commit

Permalink
fix testing imports
Browse files — browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed May 31, 2024
1 parent 59e54ab commit e624756
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 77 deletions.
Original file line number Diff line number Diff line change
@@ -1,54 +1,15 @@
from pathlib import Path
import pickle
from tabnanny import check

import pytest
import numpy as np
import h5py

from spikeinterface import load_extractor
from spikeinterface.core.testing import check_recordings_equal
from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal
from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor


@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
def test_recording_s3_nwb_ros3(tmp_path):
    """Stream an NWB recording from S3 via the HDF5 ROS3 driver and check that
    traces read correctly and the extractor survives a pickle round-trip."""
    file_path = (
        "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
    )
    recording = NwbRecordingExtractor(file_path, stream_mode="ros3")

    start_frame = 0
    end_frame = 300
    expected_num_frames = end_frame - start_frame

    num_channels = recording.get_num_channels()
    expected_dtype = recording.get_dtype()

    for segment_index in range(recording.get_num_segments()):
        # Smoke-call: segment must report its sample count without error.
        recording.get_num_samples(segment_index=segment_index)

        traces = recording.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame)
        assert traces.shape == (expected_num_frames, num_channels)
        assert traces.dtype == expected_dtype

        if recording.has_scaleable_traces():
            # Scaled traces are returned as float32 regardless of the raw dtype.
            scaled_traces = recording.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2)
            assert scaled_traces.dtype == "float32"

    pickle_path = tmp_path / "test_ros3_recording.pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(recording, f)
    with open(pickle_path, "rb") as f:
        reloaded_recording = pickle.load(f)

    check_recordings_equal(recording, reloaded_recording)


@pytest.mark.streaming_extractors
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_recording_s3_nwb_fsspec(tmp_path, cache):
Expand Down Expand Up @@ -154,37 +115,6 @@ def test_recording_s3_nwb_remfile_file_like(tmp_path):
check_recordings_equal(rec, rec2)


@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
@pytest.mark.streaming_extractors
@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
def test_sorting_s3_nwb_ros3(tmp_path):
    """Stream an NWB sorting from S3 via the HDF5 ROS3 driver, check the spike
    trains of every unit in every segment, and verify a pickle round-trip."""
    file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b"
    # 'sampling_frequency' must be supplied because this NWB file has no
    # electrical series to infer it from.
    sorting = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3", t_start=0)

    for segment_index in range(sorting.get_num_segments()):
        for unit_id in sorting.unit_ids:
            spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index)
            # Each unit must have at least one spike, stored as non-negative int64 frames.
            assert len(spike_train) > 0
            assert spike_train.dtype == "int64"
            assert np.all(spike_train >= 0)

    pickle_path = tmp_path / "test_ros3_sorting.pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(sorting, f)
    with open(pickle_path, "rb") as f:
        reloaded_sorting = pickle.load(f)

    check_sortings_equal(reloaded_sorting, sorting)


@pytest.mark.streaming_extractors
@pytest.mark.parametrize("cache", [True, False]) # Test with and without cache
def test_sorting_s3_nwb_fsspec(tmp_path, cache):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import pytest
import numpy as np
import pandas as pd
import shutil
import platform
from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

from pathlib import Path
import json

from typing import List, Optional
import scipy.signal

from spikeinterface.core import BaseRecording
from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type
Expand Down Expand Up @@ -56,6 +54,8 @@ def __init__(

def compute(self, traces, peaks, waveforms):
    """Denoise the peak waveforms with a Savitzky-Golay filter along the sample axis."""
    # Imported lazily so scipy is only required when this node actually runs.
    import scipy.signal

    return scipy.signal.savgol_filter(waveforms, self.window_length, self.order, axis=1)
5 changes: 3 additions & 2 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

matplotlib.use("Agg")

import matplotlib.pyplot as plt


from spikeinterface import (
compute_sparsity,
Expand Down Expand Up @@ -578,12 +576,15 @@ def test_plot_multicomparison(self):
for backend in possible_backends_by_sorter:
sw.plot_multicomparison_agreement_by_sorter(mcmp)
if backend == "matplotlib":
import matplotlib.pyplot as plt

_, axes = plt.subplots(len(mcmp.object_list), 1)
sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes)


if __name__ == "__main__":
# unittest.main()
import matplotlib.pyplot as plt

TestWidgets.setUpClass()
mytest = TestWidgets()
Expand Down

0 comments on commit e624756

Please sign in to comment.