Skip to content

Commit

Permalink
Merge branch 'main' into select_units_inherit_spike_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Jan 19, 2024
2 parents a552632 + c234cfb commit 3ea2f3d
Show file tree
Hide file tree
Showing 87 changed files with 3,743 additions and 1,908 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deepinterpolation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
python-version: '3.9'
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v35
uses: tj-actions/changed-files@v41
- name: Deepinteprolation changes
id: modules-changed
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/full-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
uses: ./.github/actions/show-test-environment
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v35
uses: tj-actions/changed-files@v41
- name: Module changes
id: modules-changed
run: |
Expand Down Expand Up @@ -123,6 +123,10 @@ jobs:
echo "Sortingcomponents changed"
echo "SORTINGCOMPONENTS_CHANGED=true" >> $GITHUB_OUTPUT
fi
if [[ $file == *"/generation/"* ]]; then
echo "Generation changed"
echo "GENERATION_CHANGED=true" >> $GITHUB_OUTPUT
fi
done
- name: Set execute permissions on run_tests.sh
run: chmod +x .github/run_tests.sh
Expand All @@ -149,7 +153,7 @@ jobs:
if: ${{ steps.modules-changed.outputs.SORTERS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
run: ./.github/run_tests.sh sorters
- name: Test comparison
if: ${{ steps.modules-changed.outputs.COMPARISON_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
if: ${{ steps.modules-changed.outputs.COMPARISON_CHANGED == 'true' || steps.modules-changed.outputs.GENERATION_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
run: ./.github/run_tests.sh comparison
- name: Test curation
if: ${{ steps.modules-changed.outputs.CURATION_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
Expand All @@ -166,3 +170,6 @@ jobs:
- name: Test internal sorters
if: ${{ steps.modules-changed.outputs.SORTERS_INTERNAL_CHANGED == 'true' || steps.modules-changed.outputs.SORTINGCOMPONENTS_CHANGED || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
run: ./.github/run_tests.sh sorters_internal
- name: Test generation
if: ${{ steps.modules-changed.outputs.GENERATION_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }}
run: ./.github/run_tests.sh generation
2 changes: 1 addition & 1 deletion .github/workflows/issue-on-change-matlab.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
uses: actions/checkout@v3
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v35
uses: tj-actions/changed-files@v41
with:
files: |
**/*.m
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/streaming-extractor-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
run: sudo apt install libopenblas-dev # Necessary for ROS3 support
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v35
uses: tj-actions/changed-files@v41
- name: Module changes
id: modules-changed
run: |
Expand Down
6 changes: 0 additions & 6 deletions .github/workflows/test_containers_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on: workflow_dispatch

jobs:
test-images:
env:
SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }}
name: Test on (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
Expand All @@ -30,10 +28,6 @@ jobs:
pip install pytest
pip install -e .[full]
pip install docker
- name: Test that containers install the local CI version of spikeinterface
run: |
echo $SPIKEINTERFACE_DEV_PATH
python -c "import os; assert os.getenv('SPIKEINTERFACE_DEV_PATH') is not None"
- name: Run test docker containers
run: |
pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_docker_containers.py
7 changes: 0 additions & 7 deletions .github/workflows/test_containers_singularity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on: workflow_dispatch

jobs:
test-images:
env:
SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }}
name: Test on (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
Expand Down Expand Up @@ -34,11 +32,6 @@ jobs:
pip install pytest
pip install -e .[full]
pip install spython
- name: Test that containers install the local CI version of spikeinterface
run: |
echo $SPIKEINTERFACE_DEV_PATH
python -c "import os; assert os.getenv('SPIKEINTERFACE_DEV_PATH') is not None"
ls -l
- name: Run test singularity containers
run: |
pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers.py
3 changes: 1 addition & 2 deletions .github/workflows/test_containers_singularity_gpu.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI-Test in AWS for GPU
name: Test GPU sorter images in singularity on AWS

on:
workflow_dispatch:
Expand Down Expand Up @@ -46,6 +46,5 @@ jobs:
- name: Run test singularity containers with GPU
env:
REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }}
run: |
pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.1
hooks:
- id: black
files: ^src/
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
mark_names = ["core", "extractors", "preprocessing", "postprocessing",
"sorters_external", "sorters_internal", "sorters",
"qualitymetrics", "comparison", "curation",
"widgets", "exporters", "sortingcomponents"]
"widgets", "exporters", "sortingcomponents", "generation"]


# define global test folder
Expand Down
6 changes: 6 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,9 @@ Template Matching
.. automodule:: spikeinterface.sortingcomponents.matching

.. autofunction:: find_spikes_from_templates


spikeinterface.generation
-------------------------

.. automodule:: spikeinterface.generation
9 changes: 9 additions & 0 deletions doc/modules/generation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Generation module
=================

The :py:mod:`spikeinterface.generation` provides functions to generate recordings containing spikes.
This module proposes several approaches for this including purely synthetic recordings as well as "hybrid" recordings (where templates come from true datasets).


The :py:mod:`spikeinterface.core.generate` already provides functions for generating synthetic data but this module will supply an extended and more complex
machinery, for instance generating recordings that possess various types of drift.
1 change: 1 addition & 0 deletions doc/modules/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ Modules documentation
curation
sortingcomponents
motion_correction
generation
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ classifiers = [

dependencies = [
"numpy",
"neo>=0.12.0",
"threadpoolctl",
"tqdm",
"zarr",
"neo>=0.12.0",
"probeinterface>=0.2.19",
]

Expand Down Expand Up @@ -83,7 +84,6 @@ streaming_extractors = [
]

full = [
"zarr",
"h5py",
"pandas",
"xarray",
Expand Down Expand Up @@ -112,16 +112,14 @@ qualitymetrics = [

test_core = [
"pytest",
"zarr",
"psutil",
]

test = [
"pytest",
"pytest-dependency",
"pytest-cov",

# zarr is needed for testing
"zarr",
"xarray",
"huggingface_hub",

Expand Down Expand Up @@ -180,6 +178,7 @@ dev = [
[tool.pytest.ini_options]
markers = [
"core",
"generation",
"extractors",
"preprocessing",
"postprocessing",
Expand Down
6 changes: 6 additions & 0 deletions src/spikeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@
import spikeinterface.full as si
"""

# This flag must be set to False for release
# This avoids using versioning that contains ".dev0" (and this is a better choice)
# This is mainly useful when using run_sorter in a container and spikeinterface install
DEV_MODE = True
# DEV_MODE = False
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .hybrid import (
HybridSpikesRecording,
HybridUnitsRecording,
generate_injected_sorting,
generate_sorting_to_inject,
create_hybrid_units_recording,
create_hybrid_spikes_recording,
)
4 changes: 2 additions & 2 deletions src/spikeinterface/comparison/basecomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _do_comparison(
):
# do pairwise matching
if self._verbose:
print("Multicomaprison step 1: pairwise comparison")
print("Multicomparison step 1: pairwise comparison")

self.comparisons = {}
for i in range(len(self.object_list)):
Expand Down Expand Up @@ -133,7 +133,7 @@ def _do_graph(self):

def _clean_graph(self):
if self._verbose:
print("Multicomaprison step 3: clean graph")
print("Multicomparison step 3: clean graph")
clean_graph = self.graph.copy()
import networkx as nx

Expand Down
47 changes: 7 additions & 40 deletions src/spikeinterface/comparison/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
NumpySorting,
)
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.core.generate import generate_sorting, InjectTemplatesRecording, _ensure_seed
from spikeinterface.core.generate import (
generate_sorting,
InjectTemplatesRecording,
_ensure_seed,
generate_sorting_to_inject,
)


class HybridUnitsRecording(InjectTemplatesRecording):
Expand Down Expand Up @@ -174,7 +179,7 @@ def __init__(
num_samples = [
target_recording.get_num_frames(seg_index) for seg_index in range(target_recording.get_num_segments())
]
self.injected_sorting = generate_injected_sorting(
self.injected_sorting = generate_sorting_to_inject(
target_sorting, num_samples, max_injected_per_unit, injected_rate, refractory_period_ms
)
else:
Expand All @@ -201,44 +206,6 @@ def __init__(
)


def generate_injected_sorting(
sorting: BaseSorting,
num_samples: List[int],
max_injected_per_unit: int = 1000,
injected_rate: float = 0.05,
refractory_period_ms: float = 1.5,
) -> NumpySorting:
injected_spike_trains = [{} for seg_index in range(sorting.get_num_segments())]
t_r = int(round(refractory_period_ms * sorting.get_sampling_frequency() * 1e-3))

for segment_index in range(sorting.get_num_segments()):
for unit_id in sorting.unit_ids:
spike_train = sorting.get_unit_spike_train(unit_id, segment_index=segment_index)
n_injection = min(max_injected_per_unit, int(round(injected_rate * len(spike_train))))
# Inject more, then take out all that violate the refractory period.
n = int(n_injection + 10 * np.sqrt(n_injection))
injected_spike_train = np.sort(
np.random.uniform(low=0, high=num_samples[segment_index], size=n).astype(np.int64)
)

# Remove spikes that are in the refractory period.
violations = np.where(np.diff(injected_spike_train) < t_r)[0]
injected_spike_train = np.delete(injected_spike_train, violations)

# Remove spikes that violate the refractory period of the real spikes.
# TODO: Need a better & faster way than this.
min_diff = np.min(np.abs(injected_spike_train[:, None] - spike_train[None, :]), axis=1)
violations = min_diff < t_r
injected_spike_train = injected_spike_train[~violations]

if len(injected_spike_train) > n_injection:
injected_spike_train = np.sort(np.random.choice(injected_spike_train, n_injection, replace=False))

injected_spike_trains[segment_index][unit_id] = injected_spike_train

return NumpySorting.from_unit_dict(injected_spike_trains, sorting.get_sampling_frequency())


create_hybrid_units_recording = define_function_from_class(
source_class=HybridUnitsRecording, name="create_hybrid_units_recording"
)
Expand Down
11 changes: 1 addition & 10 deletions src/spikeinterface/comparison/tests/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from spikeinterface.comparison import (
create_hybrid_units_recording,
create_hybrid_spikes_recording,
generate_injected_sorting,
)
from spikeinterface.extractors import toy_example
from spikeinterface.preprocessing import bandpass_filter
Expand Down Expand Up @@ -89,16 +88,8 @@ def test_hybrid_spikes_recording():
check_recordings_equal(hybrid_spikes_recording, saved_2job, return_scaled=False)


def test_generate_injected_sorting():
recording = load_extractor(cache_folder / "recording")
sorting = load_extractor(cache_folder / "sorting")
injected_sorting = generate_injected_sorting(
sorting, [recording.get_num_frames(seg_index) for seg_index in range(recording.get_num_segments())]
)


if __name__ == "__main__":
setup_module()
test_generate_injected_sorting()
test_generate_sorting_to_inject()
test_hybrid_units_recording()
test_hybrid_spikes_recording()
20 changes: 14 additions & 6 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from .base import load_extractor # , load_extractor_from_dict, load_extractor_from_json, load_extractor_from_pickle
from .baserecording import BaseRecording, BaseRecordingSegment
from .basesorting import BaseSorting, BaseSortingSegment
from .basesorting import BaseSorting, BaseSortingSegment, SpikeVectorSortingSegment
from .baseevent import BaseEvent, BaseEventSegment
from .basesnippets import BaseSnippets, BaseSnippetsSegment
from .baserecordingsnippets import BaseRecordingSnippets

# main extractor from dump and cache
from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary
from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting
from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets
from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor
from .numpyextractors import (
NumpyRecording,
SharedMemoryRecording,
NumpySorting,
SharedMemorySorting,
NumpyEvent,
NumpySnippets,
)
from .zarrextractors import ZarrRecordingExtractor, ZarrSortingExtractor, read_zarr, get_default_zarr_compressor
from .binaryfolder import BinaryFolderRecording, read_binary_folder
from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder
from .npysnippetsextractor import NpySnippetsExtractor, read_npy_snippets
Expand Down Expand Up @@ -79,14 +86,13 @@

# tools
from .core_tools import (
write_binary_recording,
write_to_h5_dataset_format,
write_binary_recording,
read_python,
write_python,
)
from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs
from .recording_tools import (
write_binary_recording,
write_to_h5_dataset_format,
get_random_data_chunks,
get_channel_distances,
get_closest_channels,
Expand Down Expand Up @@ -132,3 +138,5 @@

# channel sparsity
from .sparsity import ChannelSparsity, compute_sparsity

from .template import Templates
Loading

0 comments on commit 3ea2f3d

Please sign in to comment.