Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace toy_example by generate_ground_truth_recording in sorters folder #2919

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/spikeinterface/preprocessing/tests/test_phase_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@

import scipy.fft

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "preprocessing"
else:
cache_folder = Path("cache_folder") / "preprocessing"

set_global_tmp_folder(cache_folder)


def create_shifted_channel():
duration = 5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
Expand All @@ -23,7 +24,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
5 changes: 2 additions & 3 deletions src/spikeinterface/sorters/external/tests/test_kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import pytest
from pathlib import Path

from spikeinterface import load_extractor
from spikeinterface.extractors import toy_example
from spikeinterface import load_extractor, generate_ground_truth_recording
from spikeinterface.sorters import Kilosort4Sorter, run_sorter
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

Expand All @@ -23,7 +22,7 @@ def setUp(self):
if (cache_folder / "rec").is_dir():
recording = load_extractor(cache_folder / "rec")
else:
recording, _ = toy_example(num_channels=32, duration=60, seed=0, num_segments=1)
recording, _ = generate_ground_truth_recording(num_channels=32, durations=[60], seed=0)
recording = recording.save(folder=cache_folder / "rec", verbose=False, format="binary")
self.recording = recording
print(self.recording)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import pytest
from spikeinterface.extractors import toy_example
from spikeinterface.sorters import PyKilosortSorter
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
Expand All @@ -29,7 +30,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import pytest

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se

import spikeinterface.sorters as ss

os.environ["SINGULARITY_DISABLE_CACHE"] = "true"
Expand All @@ -23,7 +24,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/external/yass.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class YassSorter(BaseSorter):

1. Retraining Neural Networks (Default)

rec, sort = se.toy_example(duration=300)
rec, sort = generate_ground_truth_recording(durations=[300])
sorting_yass = ss.run_yass(rec, '/home/cat/Downloads/test2')


Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/tests/common_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import shutil

from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter
from spikeinterface.core.snippets_tools import snippets_from_sorting

Expand All @@ -24,7 +24,7 @@ class SorterCommonTestSuite:
SorterClass = None

def setUp(self):
recording, sorting_gt = toy_example(num_channels=4, duration=60, seed=0, num_segments=1)
recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0)
rec_folder = cache_folder / "rec"
if rec_folder.is_dir():
shutil.rmtree(rec_folder)
Expand Down Expand Up @@ -80,7 +80,7 @@ class SnippetsSorterCommonTestSuite:
SorterClass = None

def setUp(self):
recording, sorting_gt = toy_example(num_channels=4, duration=60, seed=0, num_segments=1)
recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0)
snippets_folder = cache_folder / "snippets"
if snippets_folder.is_dir():
shutil.rmtree(snippets_folder)
Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/sorters/tests/test_container_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import os

import spikeinterface as si
from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording

from spikeinterface.sorters.container_tools import find_recording_folders, ContainerClient, install_package_in_container

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
Expand All @@ -21,10 +22,10 @@ def setup_module():
for test_dir in test_dirs:
if test_dir.exists():
shutil.rmtree(test_dir)
rec1, _ = toy_example(num_segments=1)
rec1, _ = generate_ground_truth_recording(durations=[10])
rec1 = rec1.save(folder=cache_folder / "mono")

rec2, _ = toy_example(num_segments=3)
rec2, _ = generate_ground_truth_recording(durations=[10, 10, 10])
rec2 = rec2.save(folder=cache_folder / "multi")


Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sorters/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from spikeinterface.core import load_extractor

# from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property

Expand Down
Loading