Skip to content

Commit

Permalink
Load pca from waveform extractor and waveform from Zarr (#2613)
Browse files Browse the repository at this point in the history
* zarr and pca backward compatibility

* Add tests for back-compatibility

* Skip backward compatibility tests

* Remove zarr function
  • Loading branch information
alejoe91 authored Apr 12, 2024
1 parent abc5519 commit f26fb17
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import numpy as np

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core import generate_ground_truth_recording, SortingAnalyzer

from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor
from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms
from spikeinterface.core.waveforms_extractor_backwards_compatibility import load_waveforms as load_waveforms_backwards
from spikeinterface.core.waveforms_extractor_backwards_compatibility import _read_old_waveforms_extractor_binary

import spikeinterface.full as si

# remove this when WaveformsExtractor will be removed
from spikeinterface.core import extract_waveforms as old_extract_waveforms
Expand Down Expand Up @@ -87,23 +86,54 @@ def test_extract_waveforms():
print(mock_loaded_we_old)


@pytest.mark.skip()
@pytest.mark.skip("This test is run locally")
def test_read_old_waveforms_extractor_binary():
folder = "/data_local/DataSpikeSorting/waveform_extractor_backward_compatibility/waveforms_extractor_1"
sorting_analyzer = _read_old_waveforms_extractor_binary(folder)
import pandas as pd

print(sorting_analyzer)
folder = Path(__file__).parent / "old_waveforms"
mock_waveforms = load_waveforms_backwards(folder / "we-0.100.0")
sorting_analyzer = load_waveforms_backwards(folder / "we-0.100.0", output="SortingAnalyzer")

assert isinstance(mock_waveforms, MockWaveformExtractor)
assert isinstance(sorting_analyzer, SortingAnalyzer)

for ext_name in sorting_analyzer.get_loaded_extension_names():
print()
print(ext_name)
keys = sorting_analyzer.get_extension(ext_name).data.keys()
print(keys)
data = sorting_analyzer.get_extension(ext_name).get_data()
if isinstance(data, np.ndarray):
print(data.shape)
elif isinstance(data, pd.DataFrame):
print(data.columns)
else:
print(type(data))


# @pytest.mark.skip("This test is run locally")
# def test_read_old_waveforms_extractor_zarr():
# import pandas as pd

# folder = Path(__file__).parent / "old_waveforms"
# mock_waveforms = load_waveforms_backwards(folder / "we-0.100.0.zarr")
# sorting_analyzer = load_waveforms_backwards(folder / "we-0.100.0.zarr", output="SortingAnalyzer")

# assert isinstance(mock_waveforms, MockWaveformExtractor)
# assert isinstance(sorting_analyzer, SortingAnalyzer)

# for ext_name in sorting_analyzer.get_loaded_extension_names():
# print(ext_name)
# keys = sorting_analyzer.get_extension(ext_name).data.keys()
# print(keys)
# data = sorting_analyzer.get_extension(ext_name).get_data()
# if isinstance(data, np.ndarray):
# print(data.shape)
# elif isinstance(data, pd.DataFrame):
# print(data.columns)
# else:
# print(type(data))


if __name__ == "__main__":
test_extract_waveforms()
test_read_old_waveforms_extractor_binary()
# test_read_old_waveforms_extractor_binary()
Loading

0 comments on commit f26fb17

Please sign in to comment.