diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 723e8a702f..c2524d2c16 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -1,41 +1,20 @@ name: Install packages description: This action installs the package and its dependencies for testing -inputs: - python-version: - description: 'Python version to set up' - required: false - os: - description: 'Operating system to set up' - required: false - runs: using: "composite" steps: - name: Install dependencies run: | - sudo apt install git git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" - python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step - python -m pip install -U pip # Official recommended way - source ${{ github.workspace }}/test_env/bin/activate pip install tabulate # This produces summaries at the end pip install -e .[test,extractors,streaming_extractors,test_extractors,full] shell: bash - - name: Force installation of latest dev from key-packages when running dev (not release) - run: | - source ${{ github.workspace }}/test_env/bin/activate - spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") - if [ $spikeinterface_is_dev_version = "True" ]; then - echo "Running spikeinterface dev version" - pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo - pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface - fi - echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" + - name: Install git-annex shell: bash - - name: git-annex install run: | + pip install datalad-installer wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz mkdir /home/runner/work/installation mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ @@ -44,4 +23,14 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH cd $workdir + git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + - name: Force installation of latest dev from key-packages when running dev (not release) + run: | + spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") + if [ $spikeinterface_is_dev_version = "True" ]; then + echo "Running spikeinterface dev version" + pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo + pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface + fi + echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" shell: bash diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index dcaec8b272..a9c840d5d5 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -47,7 +47,7 @@ jobs: echo "$file was changed" done - - name: Set testing environment # This decides which tests are run and whether to install especial dependencies + - name: Set testing environment # This decides which tests are run and whether to install special dependencies shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 407c614ebf..f8ed2aa7a9 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ 
b/.github/workflows/full-test-with-codecov.yml @@ -45,7 +45,6 @@ jobs: env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell run: | - source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c4bd68be4..4c36d6fb86 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black files: ^src/ diff --git a/doc/api.rst b/doc/api.rst index 6bb9b39091..eb9a61eb9c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -346,6 +346,9 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: load_model + .. autofunction:: auto_label_units + .. autofunction:: train_model Deprecated ~~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index e3d58ca8f2..d229dc18ee 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -119,12 +119,15 @@ # for sphinx gallery plugin sphinx_gallery_conf = { - 'only_warn_on_example_error': True, + # This is the default but including here explicitly. Should build all docs and fail on gallery failures only. + # other option would be abort_on_example_error, but this fails on first failure. So we decided against this. + 'only_warn_on_example_error': False, 'examples_dirs': ['../examples/tutorials'], 'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ '../examples/tutorials/core', '../examples/tutorials/extractors', + '../examples/tutorials/curation', '../examples/tutorials/qualitymetrics', '../examples/tutorials/comparison', '../examples/tutorials/widgets', diff --git a/doc/development/development.rst b/doc/development/development.rst index 246a2bcb9a..1638c41243 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -192,6 +192,7 @@ Miscelleaneous Stylistic Conventions #. Avoid using abbreviations in variable names (e.g. use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. +#. For the titles of documentation pages, only capitalize the first letter of the first word and classes or software packages. For example, "How to use a SortingAnalyzer in SpikeInterface". #. For creating headers to divide sections of code we use the following convention (see issue `#3019 `_): @@ -212,6 +213,25 @@ We use Sphinx to build the documentation. To build the documentation locally, yo This will build the documentation in the :code:`doc/_build/html` folder. You can open the :code:`index.html` file in your browser to see the documentation. 
+Adding new documentation +------------------------ + +Documentation can be added as a +`sphinx-gallery `_ +python file ('tutorials') +or a +`sphinx rst `_ +file (all other sections). + +To add a new tutorial, add your ``.py`` file to ``spikeinterface/examples``. +Then, update the ``spikeinterface/doc/tutorials_custom_index.rst`` file +to make a new card linking to the page and an optional image. See +``tutorials_custom_index.rst`` header for more information. + +For other sections, write your documentation in ``.rst`` format and add +the page to the appropriate ``index.rst`` file found in the relevant +folder (e.g. ``how_to/index.rst``). + How to run code coverage locally -------------------------------- To run code coverage locally, you can use the following command: diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3d45606a78..bfa4335ac1 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -287,7 +287,7 @@ available parameters are dictionaries and can be accessed with: 'detect_threshold': 5, 'freq_max': 5000.0, 'freq_min': 400.0, - 'max_threads_per_process': 1, + 'max_threads_per_worker': 1, 'mp_context': None, 'n_jobs': 20, 'nested_params': None, @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions 'min_spikes': 0, 'window_size_s': 1}, 'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'}, - 'synchrony': {'synchrony_sizes': (2, 4, 8)}} + 'synchrony': {} Since the recording is very short, let’s change some parameters to diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst new file mode 100644 index 0000000000..9b1612ec12 --- /dev/null +++ b/doc/how_to/auto_curation_prediction.rst @@ -0,0 +1,43 @@ +How to use a trained model to predict the curation labels +========================================================= + +For a more detailed guide to using trained models, `read our tutorial here +`_). + +There is a Collection of models for automated curation available on the +`SpikeInterface HuggingFace page `_. + +We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` on a SortingAnalyzer +called ``sorting_analyzer``. We assume that the quality and template metrics have +already been computed. + +We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the +repo's URL after huggingface.co/) and that we trust the model. + +.. code:: + + from spikeinterface.curation import auto_label_units + + labels_and_probabilities = auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trust_model = True + ) + +If you have a local directory containing the model in a ``skops`` file you can use this to +create the labels: + +.. code:: + + labels_and_probabilities = si.auto_label_units( + sorting_analyzer = sorting_analyzer, + model_folder = "my_folder_with_a_model_in_it", + ) + +The returned labels are a dictionary of model's predictions and it's confidence. These +are also saved as a property of your ``sorting_analyzer`` and can be accessed like so: + +.. 
code:: + + labels = sorting_analyzer.sorting.get_property("classifier_label") + probabilities = sorting_analyzer.sorting.get_property("classifier_probability") diff --git a/doc/how_to/auto_curation_training.rst b/doc/how_to/auto_curation_training.rst new file mode 100644 index 0000000000..20ab57d284 --- /dev/null +++ b/doc/how_to/auto_curation_training.rst @@ -0,0 +1,58 @@ +How to train a model to predict curation labels +=============================================== + +A full tutorial for model-based curation can be found `here `_. + +Here, we assume that you have: + +* Two SortingAnalyzers called ``analyzer_1`` and + ``analyzer_2``, and have calculated some template and quality metrics for both +* Manually curated labels for the units in each analyzer, in lists called + ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can + be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``. + +With these objects calculated, you can train a model as follows + +.. code:: + + from spikeinterface.curation import train_model + + analyzer_list = [analyzer_1, analyzer_2] + labels_list = [analyzer_1_labels, analyzer_2_labels] + output_folder = "/path/to/output_folder" + + trainer = train_model( + mode="analyzers", + labels=labels_list, + analyzers=analyzer_list, + output_folder=output_folder, + metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics + imputation_strategies=None, # Default is all available imputation strategies + scaling_techniques=None, # Default is all available scaling techniques + classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available + seed=None, # Set a seed for reproducibility + ) + + +The trainer tries several models and chooses the most accurate one. This model and +some metadata are stored in the ``output_folder``, which can later be loaded using the +``load_model`` function (`more details `_). +We can also access the model, which is an sklearn ``Pipeline``, from the trainer object + +.. code:: + + best_model = trainer.best_pipeline + + +The training function can also be run in “csv” mode, if you prefer to +store metrics in as .csv files. If the target labels are stored as a column in +the file, you can point to these with the ``target_label`` parameter + +.. code:: + + trainer = train_model( + mode="csv", + metrics_paths = ["/path/to/csv_file_1", "/path/to/csv_file_2"], + target_label = "my_label", + output_folder=output_folder, + ) diff --git a/doc/how_to/benchmark_with_hybrid_recordings.rst b/doc/how_to/benchmark_with_hybrid_recordings.rst index 9975bb1a4b..5121a69690 100644 --- a/doc/how_to/benchmark_with_hybrid_recordings.rst +++ b/doc/how_to/benchmark_with_hybrid_recordings.rst @@ -2531,9 +2531,8 @@ Although non of the sorters find all units perfectly, ``Kilosort2.5``, ``Kilosort4``, and ``SpyKING CIRCUS 2`` all find around 10-12 hybrid units with accuracy greater than 80%. ``Kilosort4`` has a better overall curve, being able to find almost all units with an accuracy above 50%. -``Kilosort2.5`` performs well when looking at precision (finding all -spikes in a hybrid unit), at the cost of lower recall (finding spikes -when it shouldn’t). +``Kilosort2.5`` performs well when looking at precision (not finding spikes +when it shouldn’t), but it has a lower recall (finding all spikes in the ground truth). 
In this example, we showed how to: diff --git a/doc/how_to/combine_recordings.rst b/doc/how_to/combine_recordings.rst index db37e28382..4a088f01b1 100644 --- a/doc/how_to/combine_recordings.rst +++ b/doc/how_to/combine_recordings.rst @@ -1,4 +1,4 @@ -Combine Recordings in SpikeInterface +Combine recordings in SpikeInterface ==================================== In this tutorial we will walk through combining multiple recording objects. Sometimes this occurs due to hardware diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 5d7eae9003..7f79156a3b 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -15,3 +15,5 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. load_your_data_into_sorting benchmark_with_hybrid_recordings drift_with_lfp + auto_curation_training + auto_curation_prediction diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 1f24fb66d3..eab1e0a300 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -1,4 +1,4 @@ -Export MATLAB Data to Binary & Load in SpikeInterface +Export MATLAB data to binary & load in SpikeInterface ======================================================== In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python. diff --git a/doc/how_to/load_your_data_into_sorting.rst b/doc/how_to/load_your_data_into_sorting.rst index 4e434ecb7a..e250cfa6e9 100644 --- a/doc/how_to/load_your_data_into_sorting.rst +++ b/doc/how_to/load_your_data_into_sorting.rst @@ -1,5 +1,5 @@ -Load Your Own Data into a Sorting -================================= +Load your own data into a Sorting object +======================================== Why make a :code:`Sorting`? diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index bac0de4d0c..08a87ab738 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -1,4 +1,4 @@ -Process a Recording by Channel Group +Process a recording by channel group ==================================== In this tutorial, we will walk through how to preprocess and sort a recording diff --git a/doc/how_to/viewers.rst b/doc/how_to/viewers.rst index c7574961bd..7bb41cadb6 100644 --- a/doc/how_to/viewers.rst +++ b/doc/how_to/viewers.rst @@ -1,4 +1,4 @@ -Visualize Data +Visualize data ============== There are several ways to plot signals (raw, preprocessed) and spikes. 
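+
+For a quick first look, here is a minimal sketch (assuming a ``recording`` and a
+``sorting_analyzer`` have already been created; the variable names are placeholders) using
+the ``spikeinterface.widgets`` matplotlib backend:
+
+.. code-block:: python
+
+    import spikeinterface.widgets as sw
+
+    # plot a few seconds of the (pre)processed traces
+    sw.plot_traces(recording, time_range=(0, 5))
+
+    # plot the templates of a few units from a SortingAnalyzer
+    sw.plot_unit_templates(sorting_analyzer, unit_ids=sorting_analyzer.unit_ids[:4])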
diff --git a/doc/images/files_screen.png b/doc/images/files_screen.png new file mode 100644 index 0000000000..ef2b5b0873 Binary files /dev/null and b/doc/images/files_screen.png differ diff --git a/doc/images/hf-logo.svg b/doc/images/hf-logo.svg new file mode 100644 index 0000000000..ab959d165f --- /dev/null +++ b/doc/images/hf-logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/doc/images/initial_model_screen.png b/doc/images/initial_model_screen.png new file mode 100644 index 0000000000..b01c4248a6 Binary files /dev/null and b/doc/images/initial_model_screen.png differ diff --git a/doc/images/overview.png b/doc/images/overview.png index ea5ba49d08..e367c4b6e4 100644 Binary files a/doc/images/overview.png and b/doc/images/overview.png differ diff --git a/doc/index.rst b/doc/index.rst index ed443e4200..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -51,7 +51,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index - tutorials/index + tutorials_custom_index how_to/index modules/index api diff --git a/doc/modules/benchmark.rst b/doc/modules/benchmark.rst new file mode 100644 index 0000000000..faf53be790 --- /dev/null +++ b/doc/modules/benchmark.rst @@ -0,0 +1,141 @@ +Benchmark module +================ + +This module contains machinery to compare some sorters against ground truth in many different situations. + + +.. note:: + + In 0.102.0, the previous :py:func:`~spikeinterface.comparison.GroundTruthStudy()` has been replaced by + :py:func:`~spikeinterface.benchmark.SorterStudy()` + + +This module also aims to benchmark sorting components (detection, clustering, motion, template matching) using the +same base class :py:func:`~spikeinterface.benchmark.BenchmarkStudy()` but specialized to a targeted component. + +By design, the main class handles the concept of "levels": this allows several complexities to be compared at the same time. +For instance, compare Kilosort4 vs Kilosort2.5 (level 0) for different noise amplitudes (level 1) combined with +several motion vectors (level 2). + +**Example: compare many sorters: a ground truth study** + +We have a high-level class to compare many sorters against ground truth: :py:func:`~spikeinterface.benchmark.SorterStudy()`. + + +A study is a systematic performance comparison of several ground truth recordings with several sorters, or several cases +such as different parameter sets. + +The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" +on many recordings and then collect and aggregate results in an easy way. + +The whole mechanism is based on an intrinsic organization into a "study_folder" with several subfolders: + + * datasets: contains ground truth datasets + * sorters: contains the outputs of sorters + * sortings: contains a light copy of all sortings + * metrics: contains metrics + * ... + + +..
code-block:: python + + import matplotlib.pyplot as plt + import seaborn as sns + + import spikeinterface.extractors as se + import spikeinterface.widgets as sw + from spikeinterface.benchmark import SorterStudy + + + # generate 2 simulated datasets (could be also mearec files) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { + "toy0": (rec0, gt_sorting0), + "toy1": (rec1, gt_sorting1), + } + + # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset + # so it is a two level study (sorter_name, dataset) + # this could be more complicated like (sorter_name, dataset, params) + cases = { + ("tdc2", "toy0"): { + "label": "tridesclous2 on tetrode0", + "dataset": "toy0", + "params": {"sorter_name": "tridesclous2"} + }, + ("tdc2", "toy1"): { + "label": "tridesclous2 on tetrode1", + "dataset": "toy1", + "params": {"sorter_name": "tridesclous2"} + }, + ("sc", "toy0"): { + "label": "spykingcircus2 on tetrode0", + "dataset": "toy0", + "params": { + "sorter_name": "spykingcircus", + "docker_image": True + }, + }, + } + # this initilizes a folder + study = SorterStudy.create(study_folder=study_folder, datasets=datasets, cases=cases, + levels=["sorter_name", "dataset"]) + + + # This internally do run_sorter() for all cases in one function + study.run() + + # Run the benchmark : this internanly do compare_sorter_to_ground_truth() for all cases + study.compute_results() + + # Collect comparisons one by one + for case_key in study.cases: + print('*' * 10) + print(case_key) + # raw counting of tp/fp/... + comp = study.get_result(case_key)["gt_comparison"] + # summary + comp.print_summary() + perf_unit = comp.get_performance(method='by_unit') + perf_avg = comp.get_performance(method='pooled_with_average') + # some plots + m = comp.get_confusion_matrix() + w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) + + # Collect synthetic dataframes and display + # As shown previously, the performance is returned as a pandas dataframe. + # The spikeinterface.comparison.get_performance_by_unit() function, + # gathers all the outputs in the study folder and merges them into a single dataframe. + # Same idea for spikeinterface.comparison.get_count_units() + + # this is a dataframe + perfs = study.get_performance_by_unit() + + # this is a dataframe + unit_counts = study.get_count_units() + + # Study also have several plotting methods for plotting the result + study.plot_agreement_matrix() + study.plot_unit_counts() + study.plot_performances(mode="ordered") + study.plot_performances(mode="snr") + + + + +Benchmark spike collisions +-------------------------- + +SpikeInterface also has a specific toolset to benchmark how well sorters are at recovering spikes in "collision". + +We have three classes to handle collision-specific comparisons, and also to quantify the effects on correlogram +estimation: + + * :py:class:`~spikeinterface.comparison.CollisionGTComparison` + * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` + +For more details, checkout the following paper: + +`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?" 
`_ diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index edee7f1fda..a02d76664d 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -5,6 +5,10 @@ Comparison module SpikeInterface has a :py:mod:`~spikeinterface.comparison` module, which contains functions and tools to compare spike trains and templates (useful for tracking units over multiple sessions). +.. note:: + + In version 0.102.0 the benchmark part of comparison has moved in the new :py:mod:`~spikeinterface.benchmark` + In addition, the :py:mod:`~spikeinterface.comparison` module contains advanced benchmarking tools to evaluate the effects of spike collisions on spike sorting results, and to construct hybrid recordings for comparison. @@ -242,135 +246,6 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for cmp_gt_HS.get_redundant_units(redundant_score=0.2) - -**Example: compare many sorters with a Ground Truth Study** - -We also have a high level class to compare many sorters against ground truth: -:py:func:`~spikeinterface.comparison.GroundTruthStudy()` - -A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases -like the different parameter sets. - -The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" -on many recordings and then collect and aggregate results in an easy way. - -The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolders: - - * datasets: contains ground truth datasets - * sorters : contains outputs of sorters - * sortings: contains light copy of all sorting - * metrics: contains metrics - * ... - - -.. code-block:: python - - import matplotlib.pyplot as plt - import seaborn as sns - - import spikeinterface.extractors as se - import spikeinterface.widgets as sw - from spikeinterface.comparison import GroundTruthStudy - - - # generate 2 simulated datasets (could be also mearec files) - rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) - rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) - - datasets = { - "toy0": (rec0, gt_sorting0), - "toy1": (rec1, gt_sorting1), - } - - # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus2 on one dataset - # so it is a two level study (sorter_name, dataset) - # this could be more complicated like (sorter_name, dataset, params) - cases = { - ("tdc2", "toy0"): { - "label": "tridesclous2 on tetrode0", - "dataset": "toy0", - "run_sorter_params": { - "sorter_name": "tridesclous2", - }, - }, - ("tdc2", "toy1"): { - "label": "tridesclous2 on tetrode1", - "dataset": "toy1", - "run_sorter_params": { - "sorter_name": "tridesclous2", - }, - }, - - ("sc", "toy0"): { - "label": "spykingcircus2 on tetrode0", - "dataset": "toy0", - "run_sorter_params": { - "sorter_name": "spykingcircus", - "docker_image": True - }, - }, - } - # this initilizes a folder - study = GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases, - levels=["sorter_name", "dataset"]) - - - # all cases in one function - study.run_sorters() - - # Collect comparisons - # - # You can collect in one shot all results and run the - # GroundTruthComparison on it. - # So you can have fine access to all individual results. 
- # - # Note: use exhaustive_gt=True when you know exactly how many - # units in the ground truth (for synthetic datasets) - - # run all comparisons and loop over the results - study.run_comparisons(exhaustive_gt=True) - for key, comp in study.comparisons.items(): - print('*' * 10) - print(key) - # raw counting of tp/fp/... - print(comp.count_score) - # summary - comp.print_summary() - perf_unit = comp.get_performance(method='by_unit') - perf_avg = comp.get_performance(method='pooled_with_average') - # some plots - m = comp.get_confusion_matrix() - w_comp = sw.plot_agreement_matrix(sorting_comparison=comp) - - # Collect synthetic dataframes and display - # As shown previously, the performance is returned as a pandas dataframe. - # The spikeinterface.comparison.get_performance_by_unit() function, - # gathers all the outputs in the study folder and merges them into a single dataframe. - # Same idea for spikeinterface.comparison.get_count_units() - - # this is a dataframe - perfs = study.get_performance_by_unit() - - # this is a dataframe - unit_counts = study.get_count_units() - - # we can also access run times - run_times = study.get_run_times() - print(run_times) - - # Easy plotting with seaborn - fig1, ax1 = plt.subplots() - sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) - ax1.set_title('Run times') - - ############################################################################## - - fig2, ax2 = plt.subplots() - sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) - ax2.set_title('Recall') - ax2.set_ylim(-0.1, 1.1) - - .. _symmetric: 2. Compare the output of two spike sorters (symmetric comparison) @@ -537,35 +412,3 @@ sorting analyzers from day 1 (:code:`analyzer_day1`) to day 5 (:code:`analyzer_d # match all m_tcmp = sc.compare_multiple_templates(waveform_list=analyzer_list, name_list=["D1", "D2", "D3", "D4", "D5"]) - - - -Benchmark spike collisions --------------------------- - -SpikeInterface also has a specific toolset to benchmark how well sorters are at recovering spikes in "collision". - -We have three classes to handle collision-specific comparisons, and also to quantify the effects on correlogram -estimation: - - * :py:class:`~spikeinterface.comparison.CollisionGTComparison` - * :py:class:`~spikeinterface.comparison.CorrelogramGTComparison` - * :py:class:`~spikeinterface.comparison.CollisionGTStudy` - * :py:class:`~spikeinterface.comparison.CorrelogramGTStudy` - -For more details, checkout the following paper: - -`Samuel Garcia, Alessio P. Buccino and Pierre Yger. "How Do Spike Collisions Affect Spike Sorting Performance?" `_ - - -Hybrid recording ----------------- - -To benchmark spike sorting results, we need ground-truth spiking activity. -This can be generated with artificial simulations, e.g., using `MEArec `_, or -alternatively by generating so-called "hybrid" recordings. - -The :py:mod:`~spikeinterface.comparison` module includes functions to generate such "hybrid" recordings: - - * :py:func:`~spikeinterface.comparison.create_hybrid_units_recording`: add new units to an existing recording - * :py:func:`~spikeinterface.comparison.create_hybrid_spikes_recording`: add new spikes to existing units in a recording diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 8aa1815a55..5df9a7e6b1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -385,7 +385,7 @@ and merging unit groups. 
sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3]) sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0]) - sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3]) + sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]]) All computed extensions will be automatically propagated or merged when curating. Please refer to the :ref:`modules/curation:Curation module` documentation for more information. diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d115b33e4a..37de992806 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes redundant units from the sorting output. Redundant units are units that share over a certain percentage of spikes, by default 80%. -The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. +The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. .. code-block:: python @@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. ) # remove redundant units from SortingAnalyzer object - clean_sorting_analyzer = remove_redundant_units( + # note this returns a cleaned sorting + clean_sorting = remove_redundant_units( sorting_analyzer, duplicate_threshold=0.9, remove_strategy="min_shift" ) + # in order to have a SortingAnalyzer with only the non-redundant units one must + # select the desired units, remembering to give format and folder if one wants + # a persistent SortingAnalyzer. + clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) -We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps +We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps the unit (among the redundant ones), with a better template alignment. diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index d244fd0c0f..696dacbd3c 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, within and across spike trains. -Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count. +Synchrony metrics are computed for 2, 4 and 8 synchronous spikes. @@ -29,7 +29,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Combine a sorting and recording into a sorting_analyzer - synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/releases/0.101.2.rst b/doc/releases/0.101.2.rst new file mode 100644 index 0000000000..e54546ddfb --- /dev/null +++ b/doc/releases/0.101.2.rst @@ -0,0 +1,66 @@ +..
_release0.101.2: + +SpikeInterface 0.101.2 release notes +------------------------------------ + +4th October 2024 + +Minor release with bug fixes + +core: + +* Fix `random_spikes_selection()` (#3456) +* Expose `backend_options` at the analyzer level to set `storage_options` and `saving_options` (#3446) +* Avoid warnings in `SortingAnalyzer` (#3455) +* Fix `reset_global_job_kwargs` (#3452) +* Allow to save recordingless analyzer as (#3443) +* Fix compute analyzer pipeline with tmp recording (#3433) +* Fix bug in saving zarr recordings (#3432) +* Set `run_info` to `None` for `load_waveforms` (#3430) +* Fix integer overflow in parallel computing (#3426) +* Refactor `pandas` save load and `convert_dtypes` (#3412) +* Add spike-train based lazy `SortingGenerator` (#2227) + + +extractors: + +* Improve IBL recording extractors by PID (#3449) + +sorters: + +* Get default encoding for `Popen` (#3439) + +postprocessing: + +* Add `max_threads_per_process` and `mp_context` to pca by channel computation and PCA metrics (#3434) + +widgets: + +* Fix metrics widgets for convert_dtypes (#3417) +* Fix plot motion for multi-segment (#3414) + +motion correction: + +* Auto-cast recording to float prior to interpolation (#3415) + +documentation: + +* Add docstring for `generate_unit_locations` (#3418) +* Add `get_channel_locations` to the base recording API (#3403) + +continuous integration: + +* Enable testing arm64 Mac architecture in the CI (#3422) +* Add kachery_zone secret (#3416) + +testing: + +* Relax causal filter tests (#3445) + +Contributors: + +* @alejoe91 +* @h-mayorquin +* @jiumao2 +* @samuelgarcia +* @zm711 diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst new file mode 100644 index 0000000000..82f2c06eed --- /dev/null +++ b/doc/tutorials_custom_index.rst @@ -0,0 +1,229 @@ +.. This page provides a custom index to the 'Tutorials' page, rather than the default sphinx-gallery +.. generated page. The benefits of this are flexibility in design and inclusion of non-sphinx files in the index. +.. +.. To update this index with a new documentation page +.. 1) Copy the grid-item-card and associated ".. raw:: html" section. +.. 2) change :link: to a link to your page. If this is an `.rst` file, point to the rst file directly. +.. If it is a sphinx-gallery generated file, format the path as separated by underscore and prefix `sphx_glr`, +.. pointing to the .py file. e.g. `tutorials/my/page.py` -> `sphx_glr_tutorials_my_page.py +.. 3) Change :img-top: to point to the thumbnail image of your choosing. You can point to images generated +.. in the sphinx gallery page if you wish. +.. 4) In the `html` section, change the `default-title` to your pages title and `hover-content` to the subtitle. + +:orphan: + +Tutorials +============ + +Longer form tutorials about using SpikeInterface. Many of these are downloadable +as notebooks or Python scripts so that you can "code along" with the tutorials. + +If you're new to SpikeInterface, we recommend trying out the +:ref:`get_started/quickstart:Quickstart tutorial` first. + +Updating from legacy +-------------------- + +.. toctree:: + :maxdepth: 1 + + tutorials/waveform_extractor_to_sorting_analyzer + +Core tutorials +-------------- + +These tutorials focus on the :py:mod:`spikeinterface.core` module. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. 
grid-item-card:: Recording objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png + :img-alt: Recording objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Sorting objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png + :img-alt: Sorting objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handling probe information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png + :img-alt: Handling probe information + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingAnalyzer + :link-type: ref + :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png + :img-alt: SortingAnalyzer + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Append and/or concatenate segments + :link-type: ref + :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png + :img-alt: Append/Concatenate segments + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handle time information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_6_handle_times.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png + :img-alt: Handle time information + :class-card: gallery-card + :text-align: center + +Extractors tutorials +-------------------- + +The :py:mod:`spikeinterface.extractors` module is designed to load and save recorded and sorted data, and to handle probe information. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Read various formats + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :img-alt: Read various formats + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Working with unscaled traces + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png + :img-alt: Unscaled traces + :class-card: gallery-card + :text-align: center + +Quality metrics tutorial +------------------------ + +The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Quality Metrics + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_metrics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png + :img-alt: Quality Metrics + :class-card: gallery-card + :text-align: center + + .. 
grid-item-card:: Curation Tutorial + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :img-alt: Curation Tutorial + :class-card: gallery-card + :text-align: center + +Automated curation tutorials +---------------------------- + +Learn how to curate your units using a trained machine learning model. Or how to create +and share your own model. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Model-based curation + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_1_automated_curation.py + :img-top: /tutorials/curation/images/sphx_glr_plot_1_automated_curation_002.png + :img-alt: Model-based curation + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Train your own model + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_2_train_a_model.py + :img-top: /tutorials/curation/images/thumb/sphx_glr_plot_2_train_a_model_thumb.png + :img-alt: Train your own model + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Upload your model to HuggingFaceHub + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_3_upload_a_model.py + :img-top: /images/hf-logo.svg + :img-alt: Upload your model + :class-card: gallery-card + :text-align: center + +Comparison tutorial +------------------- + +The :code:`spikeinterface.comparison` module allows you to compare sorter outputs or benchmark against ground truth. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Sorter Comparison + :link-type: ref + :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py + :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png + :img-alt: Sorter Comparison + :class-card: gallery-card + :text-align: center + +Widgets tutorials +----------------- + +The :code:`widgets` module contains several plotting routines (widgets) for visualizing recordings, sorting data, probe layout, and more. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: RecordingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png + :img-alt: Recording Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png + :img-alt: Sorting Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Waveforms Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png + :img-alt: Waveforms Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Peaks Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png + :img-alt: Peaks Widgets + :class-card: gallery-card + :text-align: center + +Download All Examples +--------------------- + +- :download:`Download all examples in Python source code ` +- :download:`Download all examples in Jupyter notebooks ` diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index c8038387f9..2851f8ab4a 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 + releases/0.101.2.rst releases/0.101.1.rst releases/0.101.0.rst releases/0.100.8.rst @@ -44,6 +45,11 @@ Release notes releases/0.9.1.rst +Version 0.101.2 +=============== + +* Minor release with bug fixes + Version 0.101.1 =============== diff --git a/examples/how_to/benchmark_with_hybrid_recordings.py b/examples/how_to/benchmark_with_hybrid_recordings.py index abf6a25ff5..d983578797 100644 --- a/examples/how_to/benchmark_with_hybrid_recordings.py +++ b/examples/how_to/benchmark_with_hybrid_recordings.py @@ -276,7 +276,8 @@ # From the performance plots, we can see that there is no clear "winner", but `Kilosort3` definitely performs worse than the other options. # # Although non of the sorters find all units perfectly, `Kilosort2.5`, `Kilosort4`, and `SpyKING CIRCUS 2` all find around 10-12 hybrid units with accuracy greater than 80%. -# `Kilosort4` has a better overall curve, being able to find almost all units with an accuracy above 50%. `Kilosort2.5` performs well when looking at precision (finding all spikes in a hybrid unit), at the cost of lower recall (finding spikes when it shouldn't). +# `Kilosort4` has a better overall curve, being able to find almost all units with an accuracy above 50%. `Kilosort2.5` performs well when looking at precision (not finding spikes +# when it shouldn’t), but it has a lower recall (finding all spikes in the ground truth). # # # In this example, we showed how to: diff --git a/examples/tutorials/curation/README.rst b/examples/tutorials/curation/README.rst new file mode 100644 index 0000000000..0f64179e65 --- /dev/null +++ b/examples/tutorials/curation/README.rst @@ -0,0 +1,5 @@ +Curation tutorials +------------------ + +Learn how to use models to automatically curated your sorted data, or generate models +based on your own curation. diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py new file mode 100644 index 0000000000..e88b0973df --- /dev/null +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -0,0 +1,287 @@ +""" +Model-based curation tutorial +============================= + +Sorters are not perfect. They output excellent units, as well as noisy ones, and ones that +should be split or merged. Hence one should curate the generated units. Historically, this +has been done using laborious manual curation. An alternative is to use automated methods +based on metrics which quantify features of the units. In spikeinterface these are the +quality metrics and the template metrics. A simple approach is to use thresholding: +only accept units whose metrics pass a certain quality threshold. Another approach is to +take one (or more) manually labelled sortings, whose metrics have been computed, and train +a machine learning model to predict labels. + +This notebook provides a step-by-step guide on how to take a machine learning model that +someone else has trained and use it to curate your own spike sorted output. SpikeInterface +also provides the tools to train your own model, +`which you can learn about here `_. + +We'll download a toy model and use it to label our sorted data. We start by importing some packages +""" + +import warnings +warnings.filterwarnings("ignore") +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import spikeinterface.core as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# note: you can use more cores using e.g. 
+# si.set_global_job_kwargs(n_jobs = 8) + +############################################################################## +# Download a pretrained model +# --------------------------- +# +# Let's download a pretrained model from `Hugging Face `_ (HF), +# a model-sharing platform focused on AI and ML models and datasets. The +# ``load_model`` function allows us to download directly from HF, or use a model in a local +# folder. The function downloads the model, saves it in a temporary folder, and returns the +# model together with some metadata about it. + +model, model_info = sc.load_model( + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + + +############################################################################## +# This model was trained on artificially generated tetrode data. There are also models trained +# on real data, like the one discussed `below <#A-model-trained-on-real-Neuropixels-data>`_. +# Each model object has a nice html representation, which will appear if you're using a Jupyter notebook. + +model + +############################################################################## +# This tells us more information about the model. The one we've just downloaded was trained using +# a ``RandomForestClassifier``. You can also discover this information by running +# ``model.get_params()``. The model object (an `sklearn Pipeline `_) also contains information +# about which metrics were used to compute the model. We can access it from the model (or from the model_info). + +print(model.feature_names_in_) + +############################################################################## +# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed. +# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a +# bunch of extensions. Follow these links for more info on `recordings `_, `sortings `_, `sorting analyzers `_ +# and `extensions `_. + +recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10) +sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording) +sorting_analyzer.compute(['noise_levels','random_spikes','waveforms','templates','spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) +sorting_analyzer.compute('template_metrics', include_multi_channel_metrics=True) + +############################################################################## +# This sorting_analyzer now contains the required quality metrics and template metrics. +# We can check that this is true by accessing the extension data. + +all_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys()) + list(sorting_analyzer.get_extension('template_metrics').get_data().keys()) +print(set(all_metric_names) == set(model.feature_names_in_)) + +############################################################################## +# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly +# to the ``auto_label_units`` function. This returns a dictionary containing a label and +# a confidence for each unit contained in the ``sorting_analyzer``. + +labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + +print(labels) + + +############################################################################## +# The model has labelled one unit as bad.
Let's look at that one, and also the 'good' unit +# with the highest confidence of being 'good'. + +sw.plot_unit_templates(sorting_analyzer, unit_ids=['7','9']) + +############################################################################## +# Nice! Unit 9 looks more like an expected action potential waveform while unit 7 doesn't, +# and it seems reasonable that unit 7 is labelled as `bad`. However, for certain experiments +# or brain areas, unit 7 might be a great small-amplitude unit. This example highlights that +# you should be careful applying models trained on one dataset to your own dataset. You can +# explore the currently available models on the `spikeinterface hugging face hub `_ +# page, or `train your own one `_. +# +# Assess the model performance +# ---------------------------- +# +# To assess the performance of the model relative to labels assigned by a human curator, we can load or generate some +# "human labels", and plot a confusion matrix of predicted vs human labels for all clusters. Here +# we'll be a conservative human, who has labelled several units with small amplitudes as 'bad'. + +human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good'] + +# Note: if you labelled using phy, you can load the labels using: +# human_labels = sorting_analyzer.sorting.get_property('quality') +# We need to load in the `label_conversion` dictionary, which converts integers such +# as '0' and '1' to readable labels such as 'good' and 'bad'. This is stored +# in `model_info`, which we loaded earlier. + +from sklearn.metrics import confusion_matrix, balanced_accuracy_score + +label_conversion = model_info['label_conversion'] +predictions = labels['prediction'] + +conf_matrix = confusion_matrix(human_labels, predictions) + +# Calculate balanced accuracy for the confusion matrix +balanced_accuracy = balanced_accuracy_score(human_labels, predictions) + +plt.imshow(conf_matrix) +for (index, value) in np.ndenumerate(conf_matrix): + plt.annotate( str(value), xy=index, color="white", fontsize="15") +plt.xlabel('Predicted Label') +plt.ylabel('Human Label') +plt.xticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.yticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.title('Predicted vs Human Label') +plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}") +plt.show() + + +############################################################################## +# Here, there are several false positives (if we consider the human labels to be "the truth"). +# +# Next, we can also see how the model's confidence relates to the probability that the model +# label matches the human label. +# +# This could be used to help decide which units should be auto-curated and which need further +# manual curation. For example, we might accept any unit as 'good' that the model predicts +# as 'good' with confidence over a threshold, say 80%. If the confidence is lower we might decide to take a +# look at this unit manually. Below, we will create a plot that shows how the agreement +# between human and model labels changes as we increase the confidence threshold. We see that +# the agreement increases as the confidence does. So the model gets more accurate with a +# higher confidence threshold, as expected.
+ + +def calculate_moving_avg(label_df, confidence_label, window_size): + + label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop') + # Group by decile and calculate the proportion of correct labels (agreement) + p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean() + # Convert decile to range 0-1 + p_label_grouped.index = p_label_grouped.index / 10 + # Sort the DataFrame by confidence scores + label_df_sorted = label_df.sort_values(by=confidence_label) + + p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean() + + return label_df_sorted[confidence_label], p_label_moving_avg + +confidences = labels['probability'] + +# Make dataframe of human label, model label, and confidence +label_df = pd.DataFrame(data = { + 'human_label': human_labels, + 'decoder_label': predictions, + 'confidence': confidences}, + index = sorting_analyzer.sorting.get_unit_ids()) + +# Calculate the proportion of agreed labels by confidence decile +label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label'] + +p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3) + +# Plot the moving average of agreement +plt.figure(figsize=(6, 6)) +plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average') +plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance') +plt.xlabel('Confidence'); #plt.xlim(0.5, 1) +plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1) +plt.title('Agreement vs Confidence (Moving Average)') +plt.legend(); plt.grid(True); plt.show() + +############################################################################## +# In this case, you might decide to only trust labels which had confidence above 0.88, +# and manually label the ones the model isn't so confident about. +# +# A model trained on real Neuropixels data +# ---------------------------------------- +# +# Above, we used a toy model trained on generated data. There are also models on HuggingFace +# trained on real data. +# +# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in +# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and +# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into +# `noise` or `not-noise` and the other will classify the `not-noise` units into single +# unit activity (sua) units and multi-unit activity (mua) units. +# +# There is more information about each model on its HuggingFace page. Take a look! +# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
+# We can do so as follows: +# + +# Apply the noise/not-noise model +noise_neuron_labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "AnoushkaJain3/noise_neural_classifier", + trust_model=True, +) + +noise_units = noise_neuron_labels[noise_neuron_labels['prediction']=='noise'] +analyzer_neural = sorting_analyzer.remove_units(noise_units.index) + +# Apply the sua/mua model +sua_mua_labels = sc.auto_label_units( + sorting_analyzer = analyzer_neural, + repo_id = "AnoushkaJain3/sua_mua_classifier", + trust_model=True, +) + +all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index() +print(all_labels) + +############################################################################## +# If you run this without the ``trust_model=True`` parameter, you will receive an error: +# +# .. code-block:: +# +# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold'] +# +# This is a security warning, which can be overcome by passing the trusted types list +# ``trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']`` +# or by passing the ``trust_model=True`` keyword. +# +# .. dropdown:: More about security +# +# Sharing models, which are Python objects, is complicated. +# We have chosen to use the `skops format `_, instead +# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues +# `here `_). While unpacking the ``.skops`` file, each function +# is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and +# allow the object to be loaded if it only contains these (and no unknown malicious code). But +# when ``skops`` is not sure, it raises an error. Here, it doesn't recognise +# ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', +# 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', +# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions +# from `sklearn`, and we can happily add them to the ``trusted`` functions to load. +# +# In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos, +# especially from unknown sources. +# +# Directly applying a sklearn Pipeline +# ------------------------------------ +# +# Instead of using ``HuggingFace`` and ``skops``, someone might have given you a model +# in a different way: perhaps by e-mail or as a download. If you have the model in a +# folder, you can apply it in a very similar way: +# +# .. code-block:: +# +# labels = sc.auto_label_units( +# sorting_analyzer = sorting_analyzer, +# model_folder = "path/to/model/folder", +# ) + +############################################################################## +# Using this, you lose the advantages of the model metadata: the quality metric parameters +# are not checked and the labels are not converted to their original human-readable names (like
# 'good' and 'bad'). Hence we advise using the methods discussed above, when possible.
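+
+##############################################################################
+# A minimal sketch of one way to use the stored predictions downstream (assuming the
+# toy model's 'good'/'bad' label scheme; the variable names here are just examples):
+# keep only the units that the model labelled 'good'.
+#
+# .. code-block::
+#
+#     predictions = sorting_analyzer.sorting.get_property("classifier_label")
+#     good_unit_ids = sorting_analyzer.unit_ids[predictions == "good"]
+#     analyzer_good = sorting_analyzer.select_units(good_unit_ids)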
diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py
new file mode 100644
index 0000000000..1a38836527
--- /dev/null
+++ b/examples/tutorials/curation/plot_2_train_a_model.py
@@ -0,0 +1,168 @@
+"""
+Training a model for automated curation
+=======================================
+
+If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface.
+"""
+
+
+##############################################################################
+# Step 1: Generate and label data
+# -------------------------------
+#
+# First we will import our dependencies
+import warnings
+warnings.filterwarnings("ignore")
+from pathlib import Path
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import spikeinterface.core as si
+import spikeinterface.curation as sc
+import spikeinterface.widgets as sw
+
+# Note: you can set the number of cores you use with e.g.
+# si.set_global_job_kwargs(n_jobs = 8)
+
+##############################################################################
+# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll
+# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will
+# perfectly match the spikes in the recording. Hence this will contain good units. However, we've
+# uncoupled :code:`sorting_2` from the recording, so its spike times will not match the spikes in the recording.
+# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting
+# object using :code:`si.aggregate_units`.
+#
+# (When making your own model, you should
+# `load your own recording `_
+# and `do a sorting `_ on your data.)
+
+recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
+_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)
+
+both_sortings = si.aggregate_units([sorting_1, sorting_2])
+
+##############################################################################
+# To do some visualisation and postprocessing, we need to create a sorting analyzer and
+# compute some extensions:
+
+analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording)
+analyzer.compute(['noise_levels','random_spikes','waveforms','templates'])
+
+##############################################################################
+# Now we can plot the templates for the first and sixth units. The first (unit id 0) belongs to
+# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2`
+# so should look like noise.
+
+sw.plot_unit_templates(analyzer, unit_ids=["0", "5"])
+
+##############################################################################
+# This is as expected: great! (Find out more about plotting using widgets `here `_.)
+# We've set up our system so that the first five units are 'good' and the next five are 'bad'.
+# So we can make a list of labels which contains this information. For real data, you could
+# use a manual curation tool to make your own list.
+
+labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']
+
+##############################################################################
+# Step 2: Train our model
+# -----------------------
+#
+# We'll now train a model based on our labelled data.
+# The model will be trained using properties of the units, and then be applied to units from
+# other sortings. The properties we use are the
+# `quality metrics `_
+# and `template metrics `_.
+# Hence we need to compute these, using some ``sorting_analyzer`` extensions.
+
+analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics'])
+
+##############################################################################
+# Now that we have metrics and labels, we're ready to train the model using the
+# ``train_model`` function. The trainer will try several classifiers, imputation strategies and
+# scaling techniques, then save the most accurate. To save time in this tutorial,
+# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling
+# technique (standard scaler).
+#
+# We will use a list of one analyzer here, so the model is trained on a single
+# session. In reality, we would usually train a model using multiple analyzers from an
+# experiment, which should make the model more robust. To do this, you can simply pass
+# a list of analyzers and a list of manually curated labels for each
+# of these analyzers. Then the model would use all of these data as input.
+
+trainer = sc.train_model(
+    mode = "analyzers", # You can supply a labelled csv file instead of an analyzer
+    labels = [labels],
+    analyzers = [analyzer],
+    folder = "my_folder", # Where to save the model and model_info.json file
+    metric_names = None, # Specify which metrics to use for training: by default uses those already calculated
+    imputation_strategies = ["median"], # Defaults to all
+    scaling_techniques = ["standard_scaler"], # Defaults to all
+    classifiers = None, # Defaults to Random Forest only. Other classifiers you can try: ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
+    overwrite = True, # Whether or not to overwrite `folder` if it already exists. Default is False.
+    search_kwargs = {'cv': 3} # Parameters used during the model hyperparameter search
+)
+
+best_model = trainer.best_pipeline
+
+##############################################################################
+#
+# You can pass many sklearn `classifiers `_,
+# `imputation strategies `_ and
+# `scalers `_, although the
+# documentation is quite overwhelming. You can find the classifiers we've tried out
+# using the ``sc.get_default_classifier_search_spaces`` function.
+#
+# The above code saves the model in ``model.skops``, some metadata in
+# ``model_info.json`` and the model accuracies in ``model_accuracies.csv``
+# in the specified ``folder`` (in this case ``'my_folder'``).
+#
+# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more `_.)
+#
+# The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the
+# tested models. Let's take a look:
+
+accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0)
+accuracies.head()
+
+##############################################################################
+# Our model is perfect!! This is because the task was *very* easy. We had 10 units, where
+# half were pure noise and half were not.
+#
+# The model also contains some more information, such as which features are "important",
+# as defined by sklearn (learn about feature importance of a Random Forest Classifier
+# `here `_.)
+# We can plot these:
+
+# Plot feature importances
+importances = best_model.named_steps['classifier'].feature_importances_
+indices = np.argsort(importances)[::-1]
+
+# The sklearn importances are not computed for inputs whose values are all `nan`.
+# Hence, we need to pick out the non-`nan` columns of our metrics.
+features = best_model.feature_names_in_
+n_features = best_model.n_features_in_
+
+metrics = pd.concat([analyzer.get_extension('quality_metrics').get_data(), analyzer.get_extension('template_metrics').get_data()], axis=1)
+non_null_metrics = ~(metrics.isnull().all()).values
+
+features = features[non_null_metrics]
+n_features = len(features)
+
+plt.figure(figsize=(12, 7))
+plt.title("Feature Importances")
+plt.bar(range(n_features), importances[indices], align="center")
+plt.xticks(range(n_features), features[indices], rotation=90)
+plt.xlim([-1, n_features])
+plt.subplots_adjust(bottom=0.3)
+plt.show()
+
+##############################################################################
+# Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio"
+# but is not using "sync_spike_4" and "rp_contamination". This is a toy model, so don't
+# take these results seriously. But using this information, you could retrain another,
+# simpler model using a subset of the metrics, by passing, e.g.,
+# ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to the ``train_model`` function.
+#
+# Now that you have a model, you can `apply it to another sorting
+# `_
+# or `upload it to HuggingFaceHub `_.
diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py
new file mode 100644
index 0000000000..0a9ea402db
--- /dev/null
+++ b/examples/tutorials/curation/plot_3_upload_a_model.py
@@ -0,0 +1,139 @@
+"""
+Upload a pipeline to Hugging Face Hub
+=====================================
+"""
+##############################################################################
+# In this tutorial we will upload a pipeline, trained in SpikeInterface, to the
+# `Hugging Face Hub `_ (HFH).
+#
+# To do this, you first need to train a model. `Learn how here! `_
+#
+# What is Hugging Face Hub?
+# -------------------------
+# Hugging Face Hub (HFH) is a model sharing platform focused on AI and ML models and datasets.
+# To upload your own model to HFH, you need to make an account with them.
+# If you do not want to make an account, you can simply share the model folder with colleagues.
+# There are also several ways to interact with HFH: the way we propose here doesn't use
+# many of the tools ``skops`` and Hugging Face have developed, such as the ``Card`` and
+# ``hub_utils``. Feel free to check those out `here `_.
+#
+# Prepare your model
+# ------------------
+#
+# The plan is to make a folder with the following file structure:
+#
+# .. code-block::
+#
+#     my_model_folder/
+#         my_model_name.skops
+#         model_info.json
+#         training_data.csv
+#         labels.csv
+#         metadata.json
+#
+# SpikeInterface and HFH don't require you to keep this folder structure; we just advise it as
+# best practice.
+#
+# If you've used SpikeInterface to train your model, the ``train_model`` function auto-generates
+# most of this data. The only thing missing is the ``metadata.json`` file. The purpose of this
+# file is to detail how the model was trained, which can help prospective users decide if it
+# is relevant for them. For example, taking a model trained on mouse data and applying it to a
+# primate is likely a bad idea (or a great research paper!).
+# And a model trained using tetrode data might have limited application on silicon high-density
+# probes. Hence we suggest saving at least the species, brain areas and probe information, as is
+# done in the dictionary below. Note that we format the metadata so that the fields it has in
+# common with the NWB data format are consistent with it. Since the models can be trained
+# on several curations, all the metadata fields are lists:
+#
+# .. code-block::
+#
+#     import json
+#
+#     model_metadata = {
+#         "subject_species": ["Mus musculus"],
+#         "brain_areas": ["CA1"],
+#         "probes":
+#             [{
+#                 "manufacturer": "IMEC",
+#                 "name": "Neuropixels 2.0"
+#             }]
+#     }
+#     with open("my_model_folder/metadata.json", "w") as file:
+#         json.dump(model_metadata, file)
+#
+# Upload to HuggingFaceHub
+# ------------------------
+#
+# We'll now upload this folder to HFH using the web interface.
+#
+# First, go to https://huggingface.co/ and make an account. Once you've logged in, press
+# ``+`` then ``New model`` or find ``+ New Model`` in the user menu. You will be asked
+# to enter a model name, to choose a license for the model and whether the model should
+# be public or private. After you have made these choices, press ``Create Model``.
+#
+# You should be on your model's landing page, whose header looks something like this:
+#
+# .. image:: ../../images/initial_model_screen.png
+#     :width: 550
+#     :align: center
+#     :alt: The page shown on HuggingFaceHub when a user first initialises a model
+#
+# Click ``Files``, then ``+ Add file``, then ``Upload file(s)``. You can then add your files to the repository. Upload these by pressing ``Commit changes to main``.
+#
+# You are returned to the Files page, which should look similar to this:
+#
+# .. image:: ../../images/files_screen.png
+#     :width: 700
+#     :align: center
+#     :alt: The file list for a model on HuggingFaceHub.
+#
+# Let's add some information about the model for users to see when they go on your model's
+# page. Click on ``Model card`` then ``Edit model card``. Here is a sample model card for
+# a model based on synthetically generated tetrode data:
+#
+# .. code-block::
+#
+#     ---
+#     license: mit
+#     ---
+#
+#     ## Model description
+#
+#     A toy model, trained on toy data generated from spikeinterface.
+#
+#     # Intended use
+#
+#     Used to try out automated curation in SpikeInterface.
+#
+#     # How to Get Started with the Model
+#
+#     This can be used to automatically label a sorting in spikeinterface. Provided you have a `sorting_analyzer`, it is used as follows
+#
+#     ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
+#
+#     from spikeinterface.curation import auto_label_units
+#     labels = auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         repo_id = "SpikeInterface/toy_tetrode_model",
+#         trust_model=True
+#     )
+#     ` ` `
+#
+#     Or you can download the entire repository to `a_folder_for_a_model`, and use
+#
+#     ` ` ` python
+#     from spikeinterface.curation import auto_label_units
+#
+#     labels = auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         model_folder = "path/to/a_folder_for_a_model",
+#         trusted = ['numpy.dtype']
+#     )
+#     ` ` `
+#
+#     # Authors
+#
+#     Chris Halcrow
+#
+# You can see the repo with this Model card `here `_.
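+#
+# Uploading from Python instead of the browser
+# --------------------------------------------
+#
+# The steps above use the HFH web interface. If you prefer to stay in Python, the
+# ``huggingface_hub`` package (installed with SpikeInterface's ``full`` extra) can create the
+# repo and upload the folder for you. This is a sketch, not part of the SpikeInterface API:
+# ``your_username/my_tetrode_model`` and ``my_model_folder`` are placeholders, and you need to
+# be authenticated first (for example via ``huggingface-cli login``).
+#
+# .. code-block::
+#
+#     from huggingface_hub import create_repo, upload_folder
+#
+#     repo_id = "your_username/my_tetrode_model"
+#
+#     # Create the model repo if it doesn't exist yet
+#     create_repo(repo_id, repo_type="model", exist_ok=True)
+#
+#     # Push the whole model folder (the .skops file, model_info.json, metadata.json, ...)
+#     upload_folder(repo_id=repo_id, folder_path="my_model_folder", repo_type="model")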
diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_mertics.py b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py similarity index 100% rename from examples/tutorials/qualitymetrics/plot_3_quality_mertics.py rename to examples/tutorials/qualitymetrics/plot_3_quality_metrics.py diff --git a/examples/tutorials/widgets/plot_2_sort_gallery.py b/examples/tutorials/widgets/plot_2_sort_gallery.py index da5c611ce4..056b5e3a8d 100644 --- a/examples/tutorials/widgets/plot_2_sort_gallery.py +++ b/examples/tutorials/widgets/plot_2_sort_gallery.py @@ -31,14 +31,14 @@ # plot_autocorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) ############################################################################## # plot_crosscorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) plt.show() diff --git a/pyproject.toml b/pyproject.toml index c1c02db8db..98879c7302 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "spikeinterface" -version = "0.101.1" +version = "0.102.0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", @@ -23,7 +23,7 @@ dependencies = [ "numpy>=1.20, <2.0", # 1.20 np.ptp, 1.26 might be necessary for avoiding pickling errors when numpy >2.0 "threadpoolctl>=3.0.0", "tqdm", - "zarr>=2.16,<2.18", + "zarr>=2.18,<3", "neo>=0.13.0", "probeinterface>=0.2.23", "packaging", @@ -73,7 +73,7 @@ extractors = [ ] streaming_extractors = [ - "ONE-api>=2.7.0", # alf sorter and streaming IBL + "ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", @@ -101,6 +101,8 @@ full = [ "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", "numba", + "skops", + "huggingface_hub" ] widgets = [ @@ -124,16 +126,16 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_extractors = [ # Functions to download data in neo test suite "pooch>=1.8.2", "datalad>=1.0.2", - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_preprocessing = [ @@ -171,10 +173,14 @@ test = [ "torch", "pynndescent", + # curation + "skops", + 
"huggingface_hub", + # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -192,13 +198,15 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions "networkx", + "skops", # For auotmated curation + "scikit-learn", # For auotmated curation # Download data "pooch>=1.8.2", "datalad>=1.0.2", # for release we need pypi, so this needs to be commented - # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/readthedocs.yml b/readthedocs.yml index 512fcbc709..c6c44d83a0 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,9 @@ version: 2 +sphinx: + # Path to your Sphinx configuration file. + configuration: doc/conf.py + build: os: ubuntu-22.04 tools: diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 97fb95b623..306c12d516 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -# DEV_MODE = True -DEV_MODE = False +DEV_MODE = True +# DEV_MODE = False diff --git a/src/spikeinterface/benchmark/__init__.py b/src/spikeinterface/benchmark/__init__.py new file mode 100644 index 0000000000..3cf0c6a6f6 --- /dev/null +++ b/src/spikeinterface/benchmark/__init__.py @@ -0,0 +1,7 @@ +""" +Module to benchmark: + * sorters + * some sorting components (clustering, motion, template matching) +""" + +from .benchmark_sorter import SorterStudy diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_base.py similarity index 92% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py rename to src/spikeinterface/benchmark/benchmark_base.py index 4d6dd43bce..f427557677 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -11,7 +11,7 @@ from spikeinterface.core import SortingAnalyzer -from spikeinterface import load_extractor, create_sorting_analyzer, load_sorting_analyzer +from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer from spikeinterface.widgets import get_some_colors @@ -131,7 +131,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): return cls(study_folder) - def create_benchmark(self): + def create_benchmark(self, key): raise NotImplementedError def scan_folder(self): @@ -150,13 +150,13 @@ def scan_folder(self): analyzer = load_sorting_analyzer(folder) self.analyzers[key] = analyzer # the sorting is in memory here we take the saved one because comparisons need to pickle it later - 
sorting = load_extractor(analyzer.folder / "sorting") + sorting = load(analyzer.folder / "sorting") self.datasets[key] = analyzer.recording, sorting # for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): # key = rec_file.stem - # rec = load_extractor(rec_file) - # gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) + # rec = load(rec_file) + # gt_sorting = load(self.folder / f"datasets" / "gt_sortings" / key) # self.datasets[key] = (rec, gt_sorting) with open(self.folder / "cases.pickle", "rb") as f: @@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in case_keys: result_folder = self.folder / "results" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) if keep and result_folder.exists(): continue - elif not keep and result_folder.exists(): + elif not keep and (result_folder.exists() or sorter_folder.exists()): self.remove_benchmark(key) job_keys.append(key) @@ -258,25 +259,9 @@ def get_run_times(self, case_keys=None): return df def plot_run_times(self, case_keys=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - run_times = self.get_run_times(case_keys=case_keys) - - colors = self.get_colors() - import matplotlib.pyplot as plt + from .benchmark_plot_tools import plot_run_times - fig, ax = plt.subplots() - labels = [] - for i, key in enumerate(case_keys): - labels.append(self.cases[key]["label"]) - rt = run_times.at[key, "run_times"] - ax.bar(i, rt, width=0.8, color=colors[key]) - ax.set_xticks(np.arange(len(case_keys))) - ax.set_xticklabels(labels, rotation=45.0) - return fig - - # ax = run_times.plot(kind="bar") - # return ax.figure + return plot_run_times(self, case_keys=case_keys) def compute_results(self, case_keys=None, verbose=False, **result_params): if case_keys is None: @@ -443,7 +428,7 @@ def load_folder(cls, folder): elif format == "sorting": from spikeinterface.core import load_extractor - result[k] = load_extractor(folder / k) + result[k] = load(folder / k) elif format == "Motion": from spikeinterface.sortingcomponents.motion import Motion @@ -462,10 +447,3 @@ def run(self): def compute_result(self): # run becnhmark result raise NotImplementedError - - -def _simpleaxis(ax): - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.get_xaxis().tick_bottom() - ax.get_yaxis().tick_left() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py similarity index 92% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py rename to src/spikeinterface/benchmark/benchmark_clustering.py index 92fcda35d9..1c731ecb64 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -11,8 +11,7 @@ import numpy as np - -from .benchmark_tools import BenchmarkStudy, Benchmark +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -161,49 +160,21 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc return count_units - def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs): - from spikeinterface.widgets.widget_list import plot_study_unit_counts + # plotting by methods + def plot_unit_counts(self, **kwargs): + from 
.benchmark_plot_tools import plot_unit_counts - plot_study_unit_counts(self, case_keys, figsize=figsize, **extra_kwargs) + return plot_unit_counts(self, **kwargs) - def plot_agreements(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr - return fig - - def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15)): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt - - fig, axes = plt.subplots(ncols=1, nrows=3, figsize=figsize) - - for count, k in enumerate(("accuracy", "recall", "precision")): - - ax = axes[count] - for key in case_keys: - label = self.cases[key]["label"] - - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) - - if count == 2: - ax.legend() - - return fig + return plot_performances_vs_snr(self, **kwargs) def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py similarity index 59% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py rename to src/spikeinterface/benchmark/benchmark_matching.py index ab1523d13a..1934b65ef4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -9,11 +9,8 @@ ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.sortingcomponents.tools import remove_empty_templates -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.sparsity import compute_sparsity class MatchingBenchmark(Benchmark): @@ -36,7 +33,7 @@ def run(self, **job_kwargs): sorting["unit_index"] = spikes["cluster_index"] sorting["segment_index"] = spikes["segment_index"] sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids) - self.result = {"sorting": sorting} + self.result = {"sorting": sorting, "spikes": spikes} self.result["templates"] = self.templates def compute_result(self, with_collision=False, **result_params): @@ -48,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params): _run_key_saved = [ ("sorting", "sorting"), + ("spikes", "npy"), ("templates", "zarr_templates"), ] _result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")] @@ -64,45 +62,25 @@ def create_benchmark(self, key): benchmark = MatchingBenchmark(recording, gt_sorting, params) return benchmark - def plot_agreements(self, case_keys=None, figsize=None): - if case_keys is None: - case_keys = 
list(self.cases.keys()) - import pylab as plt - - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) - - for count, key in enumerate(case_keys): - ax = axs[0, count] - ax.set_title(self.cases[key]["label"]) - plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) - - def plot_performances_vs_snr(self, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): - if case_keys is None: - case_keys = list(self.cases.keys()) + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix - fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) + return plot_agreement_matrix(self, **kwargs) - for count, k in enumerate(metrics): + def plot_performances_vs_snr(self, **kwargs): + from .benchmark_plot_tools import plot_performances_vs_snr - ax = axs[count, 0] - for key in case_keys: - label = self.cases[key]["label"] + return plot_performances_vs_snr(self, **kwargs) - analyzer = self.get_sorting_analyzer(key) - metrics = analyzer.get_extension("quality_metrics").get_data() - x = metrics["snr"].values - y = self.get_result(key)["gt_comparison"].get_performance()[k].values - ax.scatter(x, y, marker=".", label=label) - ax.set_title(k) + def plot_performances_comparison(self, **kwargs): + from .benchmark_plot_tools import plot_performances_comparison - if count == 2: - ax.legend() - - return fig + return plot_performances_comparison(self, **kwargs) def plot_collisions(self, case_keys=None, figsize=None): if case_keys is None: case_keys = list(self.cases.keys()) + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) @@ -119,70 +97,6 @@ def plot_collisions(self, case_keys=None, figsize=None): return fig - def plot_comparison_matching( - self, - case_keys=None, - performance_names=["accuracy", "recall", "precision"], - colors=["g", "b", "r"], - ylim=(-0.1, 1.1), - figsize=None, - ): - - if case_keys is None: - case_keys = list(self.cases.keys()) - - num_methods = len(case_keys) - import pylab as plt - - fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10)) - for i, key1 in enumerate(case_keys): - for j, key2 in enumerate(case_keys): - if len(axs.shape) > 1: - ax = axs[i, j] - else: - ax = axs[j] - comp1 = self.get_result(key1)["gt_comparison"] - comp2 = self.get_result(key2)["gt_comparison"] - if i <= j: - for performance, color in zip(performance_names, colors): - perf1 = comp1.get_performance()[performance] - perf2 = comp2.get_performance()[performance] - ax.plot(perf2, perf1, ".", label=performance, color=color) - - ax.plot([0, 1], [0, 1], "k--", alpha=0.5) - ax.set_ylim(ylim) - ax.set_xlim(ylim) - ax.spines[["right", "top"]].set_visible(False) - ax.set_aspect("equal") - - label1 = self.cases[key1]["label"] - label2 = self.cases[key2]["label"] - if j == i: - ax.set_ylabel(f"{label1}") - else: - ax.set_yticks([]) - if i == j: - ax.set_xlabel(f"{label2}") - else: - ax.set_xticks([]) - if i == num_methods - 1 and j == num_methods - 1: - patches = [] - import matplotlib.patches as mpatches - - for color, name in zip(colors, performance_names): - patches.append(mpatches.Patch(color=color, label=name)) - ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0) - else: - ax.spines["bottom"].set_visible(False) - ax.spines["left"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - 
ax.set_yticks([]) - plt.tight_layout(h_pad=0, w_pad=0) - - return fig - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd @@ -225,6 +139,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None): plot_study_unit_counts(self, case_keys, figsize=figsize) def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py similarity index 97% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/benchmark_motion_estimation.py index ec7e1e24a8..5a3c490d38 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -8,7 +8,8 @@ import numpy as np from spikeinterface.core import get_noise_levels -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy +from .benchmark_plot_tools import _simpleaxis from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -108,6 +109,8 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) + self.result["peaks"] = peaks + self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion @@ -130,6 +133,8 @@ def compute_result(self, **result_params): self.result["motion"] = motion _run_key_saved = [ + ("peaks", "npy"), + ("peak_locations", "npy"), ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] @@ -160,7 +165,9 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift( + self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6) + ): import matplotlib.pyplot as plt if case_keys is None: @@ -194,6 +201,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + if raster: + peaks = bench.result["peaks"] + peak_locations = bench.result["peak_locations"] + rec = bench.recording + x = peaks["sample_index"] / rec.sampling_frequency + y = peak_locations[bench.direction] + ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/benchmark_motion_interpolation.py index 38365adfd1..ab72a1f9bd 100644 --- 
a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py @@ -10,7 +10,7 @@ from spikeinterface.curation import MergeUnitsSorting -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from .benchmark_base import Benchmark, BenchmarkStudy class MotionInterpolationBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py rename to src/spikeinterface/benchmark/benchmark_peak_detection.py index 7d862343d2..77b5e0025c 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -12,10 +12,9 @@ import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class PeakDetectionBenchmark(Benchmark): diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py similarity index 99% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py rename to src/spikeinterface/benchmark/benchmark_peak_localization.py index 05d142113b..399729fa29 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -6,7 +6,7 @@ compute_grid_convolution, ) import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy +from .benchmark_base import Benchmark, BenchmarkStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/benchmark/benchmark_peak_selection.py similarity index 98% rename from src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py rename to src/spikeinterface/benchmark/benchmark_peak_selection.py index 008de2d931..41edea156f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_selection.py @@ -6,15 +6,9 @@ from spikeinterface.comparison.comparisontools import make_matching_events from spikeinterface.core import get_noise_levels -import time -import string, random -import pylab as plt -import os import numpy as np -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy -from spikeinterface.core.basesorting import minimum_spike_dtype -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from .benchmark_base import Benchmark, BenchmarkStudy class PeakSelectionBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py new file mode 100644 index 0000000000..e15636ebaf --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -0,0 +1,305 @@ +import numpy as np + + +def 
_simpleaxis(ax): + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.get_xaxis().tick_bottom() + ax.get_yaxis().tick_left() + + +def plot_run_times(study, case_keys=None): + """ + Plot run times for a BenchmarkStudy. + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + run_times = study.get_run_times(case_keys=case_keys) + + colors = study.get_colors() + + fig, ax = plt.subplots() + labels = [] + for i, key in enumerate(case_keys): + labels.append(study.cases[key]["label"]) + rt = run_times.at[key, "run_times"] + ax.bar(i, rt, width=0.8, color=colors[key]) + ax.set_xticks(np.arange(len(case_keys))) + ax.set_xticklabels(labels, rotation=45.0) + return fig + + +def plot_unit_counts(study, case_keys=None): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + Parameters + ---------- + study : SorterStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + """ + import matplotlib.pyplot as plt + from spikeinterface.widgets.utils import get_some_colors + + if case_keys is None: + case_keys = list(study.cases.keys()) + + count_units = study.get_count_units(case_keys=case_keys) + + fig, ax = plt.subplots() + + columns = count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(study.cases[key]["label"]) + + ax.set_xticks(np.arange(len(case_keys)) + 1) + ax.set_xticklabels(xticklabels) + ax.legend() + + return fig + + +def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): + """ + Plot performances over case for a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + mode : "ordered" | "snr" | "swarm", default: "ordered" + Which plot mode to use: + + * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy + * "snr": plot performance metrics vs snr + * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) + performance_names : list or tuple, default: ("accuracy", "precision", "recall") + Which performances to plot ("accuracy", "precision", "recall") + case_keys : list or None + A selection of cases to plot, if None, then all. 
+ """ + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns + + if case_keys is None: + case_keys = list(study.cases.keys()) + + perfs = study.get_performance_by_unit(case_keys=case_keys) + colors = study.get_colors() + + if mode in ("ordered", "snr"): + num_axes = len(performance_names) + fig, axs = plt.subplots(ncols=num_axes) + else: + fig, ax = plt.subplots() + + if mode == "ordered": + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + for key in case_keys: + label = study.cases[key]["label"] + val = perfs.xs(key).loc[:, performance_name].values + val = np.sort(val)[::-1] + ax.plot(val, label=label, c=colors[key]) + ax.set_title(performance_name) + if count == len(performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) + + elif mode == "snr": + metric_name = mode + for count, performance_name in enumerate(performance_names): + ax = axs.flatten()[count] + + max_metric = 0 + for key in case_keys: + x = study.get_metrics(key).loc[:, metric_name].values + y = perfs.xs(key).loc[:, performance_name].values + label = study.cases[key]["label"] + ax.scatter(x, y, s=10, label=label, color=colors[key]) + max_metric = max(max_metric, np.max(x)) + ax.set_title(performance_name) + ax.set_xlim(0, max_metric * 1.05) + ax.set_ylim(0, 1.05) + if count == 0: + ax.legend(loc="lower right") + + elif mode == "swarm": + levels = perfs.index.names + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=performance_names, + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True, ax=ax) + + +def plot_agreement_matrix(study, ordered=True, case_keys=None): + """ + Plot agreement matri ces for cases in a study. + + Parameters + ---------- + study : GroundTruthStudy + A study object. + case_keys : list or None + A selection of cases to plot, if None, then all. + ordered : bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. 
+ """ + + import matplotlib.pyplot as plt + from spikeinterface.widgets import AgreementMatrixWidget + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_axes = len(case_keys) + fig, axs = plt.subplots(ncols=num_axes) + + for count, key in enumerate(case_keys): + ax = axs.flatten()[count] + comp = study.get_result(key)["gt_comparison"] + + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + AgreementMatrixWidget( + comp, ordered=ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax + ) + label = study.cases[key]["label"] + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + + +def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accuracy", "recall", "precision"]): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + fig, axs = plt.subplots(ncols=1, nrows=len(metrics), figsize=figsize, squeeze=False) + + for count, k in enumerate(metrics): + + ax = axs[count, 0] + for key in case_keys: + label = study.cases[key]["label"] + + analyzer = study.get_sorting_analyzer(key) + metrics = analyzer.get_extension("quality_metrics").get_data() + x = metrics["snr"].values + y = study.get_result(key)["gt_comparison"].get_performance()[k].values + ax.scatter(x, y, marker=".", label=label) + ax.set_title(k) + + ax.set_ylim(-0.05, 1.05) + + if count == 2: + ax.legend() + + return fig + + +def plot_performances_comparison( + study, + case_keys=None, + figsize=None, + metrics=["accuracy", "recall", "precision"], + colors=["g", "b", "r"], + ylim=(-0.1, 1.1), +): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(study.cases.keys()) + + num_methods = len(case_keys) + assert num_methods >= 2, "plot_performances_comparison need at least 2 cases!" 
+ + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False) + for i, key1 in enumerate(case_keys): + for j, key2 in enumerate(case_keys): + + if i < j: + ax = axs[i, j - 1] + + comp1 = study.get_result(key1)["gt_comparison"] + comp2 = study.get_result(key2)["gt_comparison"] + + for performance, color in zip(metrics, colors): + perf1 = comp1.get_performance()[performance] + perf2 = comp2.get_performance()[performance] + ax.scatter(perf2, perf1, marker=".", label=performance, color=color) + + ax.plot([0, 1], [0, 1], "k--", alpha=0.5) + ax.set_ylim(ylim) + ax.set_xlim(ylim) + ax.spines[["right", "top"]].set_visible(False) + ax.set_aspect("equal") + + label1 = study.cases[key1]["label"] + label2 = study.cases[key2]["label"] + + if i == j - 1: + ax.set_xlabel(label2) + ax.set_ylabel(label1) + + else: + if j >= 1 and i < num_methods - 1: + ax = axs[i, j - 1] + ax.spines[["right", "top", "left", "bottom"]].set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + ax = axs[num_methods - 2, 0] + patches = [] + from matplotlib.patches import Patch + + for color, name in zip(colors, metrics): + patches.append(Patch(color=color, label=name)) + ax.legend(handles=patches) + fig.tight_layout() + return fig diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py new file mode 100644 index 0000000000..3cf6dca04f --- /dev/null +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -0,0 +1,144 @@ +""" +This replace the previous `GroundTruthStudy` +""" + +import numpy as np +from ..core import NumpySorting +from .benchmark_base import Benchmark, BenchmarkStudy +from ..sorters import run_sorter +from spikeinterface.comparison import compare_sorter_to_ground_truth + + +# TODO later integrate CollisionGTComparison optionally in this class. + + +class SorterBenchmark(Benchmark): + def __init__(self, recording, gt_sorting, params, sorter_folder): + self.recording = recording + self.gt_sorting = gt_sorting + self.params = params + self.sorter_folder = sorter_folder + self.result = {} + + def run(self): + # run one sorter sorter_name is must be in params + raw_sorting = run_sorter(recording=self.recording, folder=self.sorter_folder, **self.params) + sorting = NumpySorting.from_sorting(raw_sorting) + self.result = {"sorting": sorting} + + def compute_result(self): + # run becnhmark result + sorting = self.result["sorting"] + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_comparison"] = comp + + _run_key_saved = [ + ("sorting", "sorting"), + ] + _result_key_saved = [ + ("gt_comparison", "pickle"), + ] + + +class SorterStudy(BenchmarkStudy): + """ + This class is used to tests several sorter in several situtation. + This replace the previous GroundTruthStudy with more flexibility. 
+ """ + + benchmark_class = SorterBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) + return benchmark + + def remove_benchmark(self, key): + BenchmarkStudy.remove_benchmark(self, key) + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + import shutil + + if sorter_folder.exists(): + shutil.rmtree(sorter_folder) + + def get_performance_by_unit(self, case_keys=None): + import pandas as pd + + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + perf = comp.get_performance(method="by_unit", output="pandas") + + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + perf_by_unit = perf_by_unit.sort_index() + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + key0 = case_keys[0] + comp = self.get_result(key0)["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + # plotting as methods + def plot_unit_counts(self, **kwargs): + from .benchmark_plot_tools import plot_unit_counts + + return plot_unit_counts(self, **kwargs) + + def plot_performances(self, **kwargs): + from .benchmark_plot_tools import plot_performances + + return plot_performances(self, **kwargs) + + def plot_agreement_matrix(self, **kwargs): + from .benchmark_plot_tools import plot_agreement_matrix + + return plot_agreement_matrix(self, **kwargs) diff --git a/src/spikeinterface/benchmark/benchmark_tools.py b/src/spikeinterface/benchmark/benchmark_tools.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/benchmark/tests/common_benchmark_testing.py similarity index 100% rename from 
src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py rename to src/spikeinterface/benchmark/tests/common_benchmark_testing.py diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py similarity index 88% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py rename to src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index bc36fb607c..3f574fd058 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -3,11 +3,13 @@ import shutil -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_clustering import ClusteringStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel +from pathlib import Path + @pytest.mark.skip() def test_benchmark_clustering(create_cache_folder): @@ -78,4 +80,5 @@ def test_benchmark_clustering(create_cache_folder): if __name__ == "__main__": - test_benchmark_clustering() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_clustering(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py rename to src/spikeinterface/benchmark/tests/test_benchmark_matching.py index 71a5f282a8..000a00faf5 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py @@ -1,6 +1,7 @@ import pytest import shutil +from pathlib import Path from spikeinterface.core import ( @@ -8,11 +9,11 @@ compute_sparsity, ) -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_dataset, compute_gt_templates, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_matching import MatchingStudy +from spikeinterface.benchmark.benchmark_matching import MatchingStudy @pytest.mark.skip() @@ -72,4 +73,5 @@ def test_benchmark_matching(create_cache_folder): if __name__ == "__main__": - test_benchmark_matching() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_matching(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py similarity index 86% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py index 78a9eb7dbc..65cacfc8a0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_estimation.py @@ -2,12 +2,13 @@ import shutil +from pathlib import Path -from 
spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import MotionEstimationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import MotionEstimationStudy @pytest.mark.skip() @@ -75,4 +76,5 @@ def test_benchmark_motion_estimaton(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_estimaton() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_estimaton(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py similarity index 90% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py rename to src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py index 18def37d54..f7afd7a8bc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py @@ -4,14 +4,14 @@ import numpy as np import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import ( +from spikeinterface.benchmark.tests.common_benchmark_testing import ( make_drifting_dataset, ) -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( +from spikeinterface.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy +from spikeinterface.benchmark.benchmark_motion_estimation import ( # get_unit_displacement, get_gt_motion_from_unit_displacement, ) @@ -139,4 +139,5 @@ def test_benchmark_motion_interpolation(create_cache_folder): if __name__ == "__main__": - test_benchmark_motion_interpolation() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_motion_interpolation(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py similarity index 87% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index dffe1529b7..d45ac0b4ce 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -1,10 +1,10 @@ import pytest import shutil +from pathlib import Path - -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_detection import PeakDetectionStudy +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.template_tools import get_template_extremum_channel @@ -69,5 +69,5 @@ def test_benchmark_peak_detection(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - 
test_benchmark_peak_detection() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_detection(cache_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py similarity index 79% rename from src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py rename to src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py index 23060c4ddb..3b6240cb10 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py @@ -1,12 +1,12 @@ import pytest import shutil +from pathlib import Path +from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset -from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset - -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import PeakLocalizationStudy -from spikeinterface.sortingcomponents.benchmark.benchmark_peak_localization import UnitLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import PeakLocalizationStudy +from spikeinterface.benchmark.benchmark_peak_localization import UnitLocalizationStudy @pytest.mark.skip() @@ -28,7 +28,8 @@ def test_benchmark_peak_localization(create_cache_folder): "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, "params": { "method": method, - "method_kwargs": {"ms_before": 2}, + "ms_before": 2.0, + "method_kwargs": {}, }, } @@ -60,7 +61,7 @@ def test_benchmark_unit_locations(create_cache_folder): cache_folder = create_cache_folder job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # create study study_folder = cache_folder / "study_unit_locations" @@ -71,7 +72,7 @@ def test_benchmark_unit_locations(create_cache_folder): "label": f"{method} on toy", "dataset": "toy", "init_kwargs": {"gt_positions": gt_sorting.get_property("gt_unit_locations")}, - "params": {"method": method, "method_kwargs": {"ms_before": 2}}, + "params": {"method": method, "ms_before": 2.0, "method_kwargs": {}}, } if study_folder.exists(): @@ -99,5 +100,6 @@ def test_benchmark_unit_locations(create_cache_folder): if __name__ == "__main__": - # test_benchmark_peak_localization() - test_benchmark_unit_locations() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + # test_benchmark_peak_localization(cache_folder) + test_benchmark_unit_locations(cache_folder) diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py new file mode 100644 index 0000000000..a6eb090a9d --- /dev/null +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py @@ -0,0 +1,13 @@ +import pytest + +from pathlib import Path + + +@pytest.mark.skip() +def test_benchmark_peak_selection(create_cache_folder): + cache_folder = create_cache_folder + + +if __name__ == "__main__": + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" + test_benchmark_peak_selection(cache_folder) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py similarity index 57% rename from 
src/spikeinterface/comparison/tests/test_groundtruthstudy.py rename to src/spikeinterface/benchmark/tests/test_benchmark_sorter.py index a92d6e9f77..db48d32fde 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_sorter.py @@ -4,12 +4,12 @@ from spikeinterface import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter -from spikeinterface.comparison import GroundTruthStudy +from spikeinterface.benchmark import SorterStudy @pytest.fixture(scope="module") def setup_module(tmp_path_factory): - study_folder = tmp_path_factory.mktemp("study_folder") + study_folder = tmp_path_factory.mktemp("sorter_study_folder") if study_folder.is_dir(): shutil.rmtree(study_folder) create_a_study(study_folder) @@ -36,63 +36,55 @@ def create_a_study(study_folder): ("tdc2", "no-preprocess", "tetrode"): { "label": "tridesclous2 without preprocessing and standard params", "dataset": "toy_tetrode", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, # ("tdc2", "with-preprocess", "probe32"): { "label": "tridesclous2 with preprocessing standar params", "dataset": "toy_probe32_preprocess", - "run_sorter_params": { + "params": { "sorter_name": "tridesclous2", }, - "comparison_params": {}, }, - # we comment this at the moement because SC2 is quite slow for testing - # ("sc2", "no-preprocess", "tetrode"): { - # "label": "spykingcircus2 without preprocessing standar params", - # "dataset": "toy_tetrode", - # "run_sorter_params": { - # "sorter_name": "spykingcircus2", - # }, - # "comparison_params": { - # }, - # }, } - study = GroundTruthStudy.create( + study = SorterStudy.create( study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] ) # print(study) -def test_GroundTruthStudy(setup_module): +def test_SorterStudy(setup_module): + # job_kwargs = dict(n_jobs=2, chunk_duration="1s") + study_folder = setup_module - study = GroundTruthStudy(study_folder) + study = SorterStudy(study_folder) print(study) - study.run_sorters(verbose=True) - - print(study.sortings) - - print(study.comparisons) - study.run_comparisons() - print(study.comparisons) - - study.create_sorting_analyzer_gt(n_jobs=-1) + # # this run the sorters + study.run() - study.compute_metrics() + # # this run comparisons + study.compute_results() + print(study) - for key in study.cases: - metrics = study.get_metrics(key) - print(metrics) + # this is from the base class + rt = study.get_run_times() + # rt = study.plot_run_times() + # import matplotlib.pyplot as plt + # plt.show() - study.get_performance_by_unit() - study.get_count_units() + perf_by_unit = study.get_performance_by_unit() + # print(perf_by_unit) + count_units = study.get_count_units() + # print(count_units) if __name__ == "__main__": - setup_module() - test_GroundTruthStudy() + study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy" + if study_folder.exists(): + shutil.rmtree(study_folder) + create_a_study(study_folder) + test_SorterStudy(study_folder) diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index 648ef4ed70..f4ada19f73 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -30,8 +30,8 @@ ) from .groundtruthstudy import GroundTruthStudy -from .collision import CollisionGTComparison, CollisionGTStudy -from .correlogram import 
CorrelogramGTComparison, CorrelogramGTStudy +from .collision import CollisionGTComparison +from .correlogram import CorrelogramGTComparison from .hybrid import ( HybridSpikesRecording, diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 574bd16093..12bfab84ed 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -172,71 +172,75 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good return similarities, recall_scores, pair_names -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["collision_lag"] = collision_lag - _kwargs["nbins"] = nbins - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self, key): - comp = self.comparisons[key] - fs = comp.sorting1.get_sampling_frequency() - lags = comp.bins / fs * 1000.0 - return lags - - def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): - import sklearn - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity, good_only=good_only, min_accuracy=min_accuracy - ) - self.all_similarities[key] = similarities - self.all_recall_scores[key] = recall_scores - - def get_mean_over_similarity_range(self, similarity_range, key): - idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) - all_similarities = self.all_similarities[key][idx] - all_recall_scores = self.all_recall_scores[key][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_recall_scores = self.all_recall_scores[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark. 
+# please do not delete this + +# class CollisionGTStudy(GroundTruthStudy): +# def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["collision_lag"] = collision_lag +# _kwargs["nbins"] = nbins +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt +# self.collision_lag = collision_lag + +# def get_lags(self, key): +# comp = self.comparisons[key] +# fs = comp.sorting1.get_sampling_frequency() +# lags = comp.bins / fs * 1000.0 +# return lags + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): +# import sklearn + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_recall_scores = {} +# self.good_only = good_only + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( +# similarity, good_only=good_only, min_accuracy=min_accuracy +# ) +# self.all_similarities[key] = similarities +# self.all_recall_scores[key] = recall_scores + +# def get_mean_over_similarity_range(self, similarity_range, key): +# idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) +# all_similarities = self.all_similarities[key][idx] +# all_recall_scores = self.all_recall_scores[key][idx] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + +# return mean_recall_scores + +# def get_lag_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_recall_scores = self.all_recall_scores[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_recall_scores = all_recall_scores[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_recall_scores + +# return result diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 0cafef2c12..717d11a3fa 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -128,57 +128,60 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): return similarities, errors -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons( - self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs - ): - _kwargs = dict() - _kwargs.update(kwargs) - _kwargs["exhaustive_gt"] = exhaustive_gt - _kwargs["window_ms"] = window_ms - _kwargs["bin_ms"] = bin_ms - _kwargs["well_detected_score"] = well_detected_score - GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for 
key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, case_keys=None, good_only=True): - import sklearn.metrics - - if case_keys is None: - case_keys = self.cases.keys() - - self.all_similarities = {} - self.all_errors = {} - - for key in case_keys: - templates = self.get_templates(key) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - comp = self.comparisons[key] - similarities, errors = comp.compute_correlogram_by_similarity(similarity) - - self.all_similarities[key] = similarities - self.all_errors[key] = errors - - def get_error_profile_over_similarity_bins(self, similarity_bins, key): - all_similarities = self.all_similarities[key] - all_errors = self.all_errors[key] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result +# This is removed at the moment. +# We need to move this maybe one day in benchmark + +# class CorrelogramGTStudy(GroundTruthStudy): +# def run_comparisons( +# self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs +# ): +# _kwargs = dict() +# _kwargs.update(kwargs) +# _kwargs["exhaustive_gt"] = exhaustive_gt +# _kwargs["window_ms"] = window_ms +# _kwargs["bin_ms"] = bin_ms +# _kwargs["well_detected_score"] = well_detected_score +# GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) +# self.exhaustive_gt = exhaustive_gt + +# @property +# def time_bins(self): +# for key, value in self.comparisons.items(): +# return value.time_bins + +# def precompute_scores_by_similarities(self, case_keys=None, good_only=True): +# import sklearn.metrics + +# if case_keys is None: +# case_keys = self.cases.keys() + +# self.all_similarities = {} +# self.all_errors = {} + +# for key in case_keys: +# templates = self.get_templates(key) +# flat_templates = templates.reshape(templates.shape[0], -1) +# similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) +# comp = self.comparisons[key] +# similarities, errors = comp.compute_correlogram_by_similarity(similarity) + +# self.all_similarities[key] = similarities +# self.all_errors[key] = errors + +# def get_error_profile_over_similarity_bins(self, similarity_bins, key): +# all_similarities = self.all_similarities[key] +# all_errors = self.all_errors[key] + +# order = np.argsort(all_similarities) +# all_similarities = all_similarities[order] +# all_errors = all_errors[order, :] + +# result = {} + +# for i in range(similarity_bins.size - 1): +# cmin, cmax = similarity_bins[i], similarity_bins[i + 1] +# amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) +# mean_errors = np.nanmean(all_errors[amin:amax], axis=0) +# result[(cmin, cmax)] = mean_errors + +# return result diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 8929d6983c..df9e1420cb 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,441 +1,21 @@ -from __future__ import annotations - -from pathlib import Path -import 
shutil -import os -import json -import pickle - -import numpy as np - -from spikeinterface.core import load_extractor, create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder - -from spikeinterface.qualitymetrics import compute_quality_metrics - -from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison - - -# TODO later : save comparison in folders when comparison object will be able to serialize - - -# This is to separate names when the key are tuples when saving folders -# _key_separator = "_##_" -_key_separator = "_-°°-_" +_txt_error_message = """ +GroundTruthStudy has been replaced by SorterStudy, which has a similar API but cannot load folders created with the old format. +You can do: +from spikeinterface.benchmark import SorterStudy +study = SorterStudy.create(study_folder, datasets=..., cases=..., levels=...) +study.run() # this runs the sorters +study.compute_results() # this runs the comparisons +# and then some plotting +study.plot_agreements() +study.plot_performances_vs_snr() +... +""" class GroundTruthStudy: - """ - This class is an helper function to run any comparison on several "cases" for many ground-truth dataset. - - "cases" refer to: - * several sorters for comparisons - * same sorter with differents parameters - * any combination of these (and more) - - For increased flexibility, cases keys can be a tuple so that we can vary complexity along several - "levels" or "axis" (paremeters or sorters). - In this case, the result dataframes will have `MultiIndex` to handle the different levels. - - A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see - :py:func:`~spikeinterface.core.generate.generate_ground_truth_recording()`). - - This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. - Note that the underlying folder structure is not backward compatible! 
- - Parameters - ---------- - study_folder : str | Path - Path to folder containing `GroundTruthStudy` - """ - def __init__(self, study_folder): - self.folder = Path(study_folder) - - self.datasets = {} - self.cases = {} - self.sortings = {} - self.comparisons = {} - self.colors = None - - self.scan_folder() + raise RuntimeError(_txt_error_message) @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): - # check that cases keys are homogeneous - key0 = list(cases.keys())[0] - if isinstance(key0, str): - assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" - if levels is None: - levels = "level0" - else: - assert isinstance(levels, str) - elif isinstance(key0, tuple): - assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" - num_levels = len(key0) - assert all( - len(key) == num_levels for key in cases.keys() - ), "Keys for cases are not homogeneous, tuple negth differ" - if levels is None: - levels = [f"level{i}" for i in range(num_levels)] - else: - levels = list(levels) - assert len(levels) == num_levels - else: - raise ValueError("Keys for cases must str or tuple") - - study_folder = Path(study_folder) - study_folder.mkdir(exist_ok=False, parents=True) - - (study_folder / "datasets").mkdir() - (study_folder / "datasets" / "recordings").mkdir() - (study_folder / "datasets" / "gt_sortings").mkdir() - (study_folder / "sorters").mkdir() - (study_folder / "sortings").mkdir() - (study_folder / "sortings" / "run_logs").mkdir() - (study_folder / "metrics").mkdir() - (study_folder / "comparisons").mkdir() - - for key, (rec, gt_sorting) in datasets.items(): - assert "/" not in key, "'/' cannot be in the key name!" - assert "\\" not in key, "'\\' cannot be in the key name!" 
- - # recordings are pickled - rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - - # sortings are pickled + saved as NumpyFolderSorting - gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") - gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - - info = {} - info["levels"] = levels - (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - - # cases is dumped to a pickle file, json is not possible because of the tuple key - (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) - - return cls(study_folder) - - def scan_folder(self): - if not (self.folder / "datasets").exists(): - raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") - - with open(self.folder / "info.json", "r") as f: - self.info = json.load(f) - - self.levels = self.info["levels"] - - for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): - key = rec_file.stem - rec = load_extractor(rec_file) - gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) - self.datasets[key] = (rec, gt_sorting) - - with open(self.folder / "cases.pickle", "rb") as f: - self.cases = pickle.load(f) - - self.sortings = {k: None for k in self.cases} - self.comparisons = {k: None for k in self.cases} - for key in self.cases: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - if sorting_folder.exists(): - self.sortings[key] = load_extractor(sorting_folder) - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - if comparison_file.exists(): - with open(comparison_file, mode="rb") as f: - try: - self.comparisons[key] = pickle.load(f) - except Exception: - pass - - def __repr__(self): - t = f"{self.__class__.__name__} {self.folder.stem} \n" - t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" - t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" - num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) - t += f" computed: {num_computed}\n" - - return t - - def key_to_str(self, key): - if isinstance(key, str): - return key - elif isinstance(key, tuple): - return _key_separator.join(key) - else: - raise ValueError("Keys for cases must str or tuple") - - def remove_sorting(self, key): - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - comparison_file = self.folder / "comparisons" / self.key_to_str(key) - self.sortings[key] = None - self.comparisons[key] = None - if sorting_folder.exists(): - shutil.rmtree(sorting_folder) - for f in (log_file, comparison_file): - if f.exists(): - f.unlink() - - def set_colors(self, colors=None, map_name="tab20"): - from spikeinterface.widgets import get_some_colors - - if colors is None: - case_keys = list(self.cases.keys()) - self.colors = get_some_colors( - case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 - ) - else: - self.colors = colors - - def get_colors(self): - if self.colors is None: - self.set_colors() - return self.colors - - def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): - if case_keys is None: - case_keys = self.cases.keys() - - job_list = [] - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorting_exists = sorting_folder.exists() - - sorter_folder = self.folder / "sorters" / 
self.key_to_str(key) - sorter_folder_exists = sorter_folder.exists() - - if keep: - if sorting_exists: - continue - if sorter_folder_exists: - # the sorter folder exists but havent been copied to sortings folder - sorting = read_sorter_folder(sorter_folder, raise_error=False) - if sorting is not None: - # save and skip - self.copy_sortings(case_keys=[key]) - continue - - self.remove_sorting(key) - - if sorter_folder_exists: - shutil.rmtree(sorter_folder) - - params = self.cases[key]["run_sorter_params"].copy() - # this ensure that sorter_name is given - recording, _ = self.datasets[self.cases[key]["dataset"]] - sorter_name = params.pop("sorter_name") - job = dict( - sorter_name=sorter_name, - recording=recording, - output_folder=sorter_folder, - ) - job.update(params) - # the verbose is overwritten and global to all run_sorters - job["verbose"] = verbose - job["with_output"] = False - job_list.append(job) - - run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) - - # TODO later create a list in laucher for engine blocking and non-blocking - if engine not in ("slurm",): - self.copy_sortings(case_keys) - - def copy_sortings(self, case_keys=None, force=True): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - sorting_folder = self.folder / "sortings" / self.key_to_str(key) - sorter_folder = self.folder / "sorters" / self.key_to_str(key) - log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - - if (sorter_folder / "spikeinterface_log.json").exists(): - sorting = read_sorter_folder( - sorter_folder, raise_error=False, register_recording=False, sorting_info=False - ) - else: - sorting = None - - if sorting is not None: - if sorting_folder.exists(): - if force: - self.remove_sorting(key) - else: - continue - - sorting = sorting.save(format="numpy_folder", folder=sorting_folder) - self.sortings[key] = sorting - - # copy logs - shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) - - def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - _, gt_sorting = self.datasets[dataset_key] - sorting = self.sortings[key] - if sorting is None: - self.comparisons[key] = None - continue - comp = comparison_class(gt_sorting, sorting, **kwargs) - self.comparisons[key] = comp - - comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle") - with open(comparison_file, mode="wb") as f: - pickle.dump(comp, f) - - def get_run_times(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - log_folder = self.folder / "sortings" / "run_logs" - - run_times = {} - for key in case_keys: - log_file = log_folder / f"{self.key_to_str(key)}.json" - with open(log_file, mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times[key] = run_time - - return pd.Series(run_times, name="run_time") - - def create_sorting_analyzer_gt(self, case_keys=None, random_params={}, waveforms_params={}, **job_kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - base_folder = self.folder / "sorting_analyzer" - base_folder.mkdir(exist_ok=True) - - dataset_keys = [self.cases[key]["dataset"] for key in case_keys] - dataset_keys = set(dataset_keys) - for dataset_key in dataset_keys: - # the waveforms depend on the dataset key - folder = base_folder / 
self.key_to_str(dataset_key) - recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder) - sorting_analyzer.compute("random_spikes", **random_params) - sorting_analyzer.compute("templates", **job_kwargs) - sorting_analyzer.compute("noise_levels") - - def get_sorting_analyzer(self, case_key=None, dataset_key=None): - if case_key is not None: - dataset_key = self.cases[case_key]["dataset"] - - folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) - sorting_analyzer = load_sorting_analyzer(folder) - return sorting_analyzer - - # def get_templates(self, key, mode="average"): - # analyzer = self.get_sorting_analyzer(case_key=key) - # templates = sorting_analyzer.get_all_templates(mode=mode) - # return templates - - def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): - if case_keys is None: - case_keys = self.cases.keys() - - done = [] - for key in case_keys: - dataset_key = self.cases[key]["dataset"] - if dataset_key in done: - # some case can share the same waveform extractor - continue - done.append(dataset_key) - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if filename.exists(): - if force: - os.remove(filename) - else: - continue - analyzer = self.get_sorting_analyzer(key) - metrics = compute_quality_metrics(analyzer, metric_names=metric_names) - metrics.to_csv(filename, sep="\t", index=True) - - def get_metrics(self, key): - import pandas as pd - - dataset_key = self.cases[key]["dataset"] - - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if not filename.exists(): - return - metrics = pd.read_csv(filename, sep="\t", index_col=0) - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - metrics.index = gt_sorting.unit_ids - return metrics - - def get_units_snr(self, key): - return self.get_metrics(key)["snr"] - - def get_performance_by_unit(self, case_keys=None): - import pandas as pd - - if case_keys is None: - case_keys = self.cases.keys() - - perf_by_unit = [] - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - perf = comp.get_performance(method="by_unit", output="pandas") - - if isinstance(key, str): - perf[self.levels] = key - elif isinstance(key, tuple): - for col, k in zip(self.levels, key): - perf[col] = k - - perf = perf.reset_index() - perf_by_unit.append(perf) - - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(self.levels) - perf_by_unit = perf_by_unit.sort_index() - return perf_by_unit - - def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): - import pandas as pd - - if case_keys is None: - case_keys = list(self.cases.keys()) - - if isinstance(case_keys[0], str): - index = pd.Index(case_keys, name=self.levels) - else: - index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - - columns = ["num_gt", "num_sorter", "num_well_detected"] - comp = self.comparisons[case_keys[0]] - if comp.exhaustive_gt: - columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) - count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - - for key in case_keys: - comp = self.comparisons.get(key, None) - assert comp is not None, "You need to do study.run_comparisons() first" - - gt_sorting = comp.sorting1 - sorting = 
comp.sorting2 - - count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) - - if comp.exhaustive_gt: - count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) - count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) - count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) - count_units.loc[key, "num_bad"] = comp.count_bad_units() - - return count_units + raise RuntimeError(_txt_error_message) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index f7d9782a07..6a4be86796 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -7,7 +7,7 @@ import numpy as np -from spikeinterface.core import load_extractor, BaseSorting, BaseSortingSegment +from spikeinterface.core import load, BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class from .basecomparison import BaseMultiComparison, MixinSpikeTrainComparison, MixinTemplateComparison from .paircomparisons import SymmetricSortingComparison, TemplateComparison @@ -230,7 +230,7 @@ def load_from_folder(folder_path): with (folder_path / "sortings.json").open() as f: dict_sortings = json.load(f) name_list = list(dict_sortings.keys()) - sorting_list = [load_extractor(v, base_folder=folder_path) for v in dict_sortings.values()] + sorting_list = [load(v, base_folder=folder_path) for v in dict_sortings.values()] mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=list(name_list), do_matching=False, **kwargs) filename = str(folder_path / "multicomparison.gpickle") with open(filename, "rb") as f: diff --git a/src/spikeinterface/comparison/tests/test_hybrid.py b/src/spikeinterface/comparison/tests/test_hybrid.py index ce409ca778..22cc141f65 100644 --- a/src/spikeinterface/comparison/tests/test_hybrid.py +++ b/src/spikeinterface/comparison/tests/test_hybrid.py @@ -1,7 +1,7 @@ import pytest import shutil from pathlib import Path -from spikeinterface.core import extract_waveforms, load_waveforms, load_extractor +from spikeinterface.core import extract_waveforms, load_waveforms, load from spikeinterface.core.testing import check_recordings_equal from spikeinterface.comparison import ( create_hybrid_units_recording, @@ -52,7 +52,7 @@ def test_hybrid_units_recording(setup_module): ) # Check dumpability - saved_loaded = load_extractor(hybrid_units_recording.to_dict()) + saved_loaded = load(hybrid_units_recording.to_dict()) check_recordings_equal(hybrid_units_recording, saved_loaded, return_scaled=False) saved_1job = hybrid_units_recording.save(folder=cache_folder / "units_1job") @@ -81,7 +81,7 @@ def test_hybrid_spikes_recording(setup_module): ) # Check dumpability - saved_loaded = load_extractor(hybrid_spikes_recording.to_dict()) + saved_loaded = load(hybrid_spikes_recording.to_dict()) check_recordings_equal(hybrid_spikes_recording, saved_loaded, return_scaled=False) saved_1job = hybrid_spikes_recording.save(folder=cache_folder / "spikes_1job") diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index ead7007920..f09458f6a6 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -1,10 +1,11 @@ -from .base import load_extractor # , 
load_extractor_from_dict, load_extractor_from_json, load_extractor_from_pickle from .baserecording import BaseRecording, BaseRecordingSegment from .basesorting import BaseSorting, BaseSortingSegment, SpikeVectorSortingSegment from .baseevent import BaseEvent, BaseEventSegment from .basesnippets import BaseSnippets, BaseSnippetsSegment from .baserecordingsnippets import BaseRecordingSnippets +from .loading import load, load_extractor + # main extractor from dump and cache from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting @@ -90,7 +91,14 @@ write_python, normal_pdf, ) -from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs +from .job_tools import ( + get_best_job_kwargs, + ensure_n_jobs, + ensure_chunk_size, + ChunkRecordingExecutor, + split_job_kwargs, + fix_job_kwargs, +) from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..447bbe562e 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -22,21 +22,23 @@ class ComputeRandomSpikes(AnalyzerExtension): """ - AnalyzerExtension that select some random spikes. + AnalyzerExtension that selects some random spikes. + This allows for a subsampling of spikes for further calculations and is important + for managing the amount of memory and the speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters are the same. + This internally uses `random_spikes_selection()` parameters. Parameters ---------- - method: "uniform" | "all", default: "uniform" + method : "uniform" | "all", default: "uniform" The method to select the spikes - max_spikes_per_unit: int, default: 500 + max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" - margin_size: int, default: None + margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" - seed: int or None, default: None + seed : int or None, default: None A seed for the random generator, ignored if method="all" Returns @@ -104,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility + # useful for WaveformExtractor backwards compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] @@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): Parameters ---------- - ms_before: float, default: 1.0 + ms_before : float, default: 1.0 The number of ms to extract before the spike events - ms_after: float, default: 2.0 + ms_after : float, default: 2.0 The number of ms to extract after the spike events - dtype: None | dtype, default: None + dtype : None | dtype, default: None The dtype of the waveforms. If None, the dtype of the recording is used. 
Returns ------- - waveforms: np.ndarray + waveforms : np.ndarray Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) """ @@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") + if operator not in ("average", "std", "median", "mad"): + error_msg = ( + f"You have entered an operator {operator} in your `operators` argument which is " + f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." + ) + raise ValueError(error_msg) else: assert isinstance(operator, (list, tuple)) assert len(operator) == 2 @@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): self._compute_and_append_from_waveforms(self.params["operators"]) else: - for operator in self.params["operators"]: - if operator not in ("average", "std"): - raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") + bad_operator_list = [ + operator for operator in self.params["operators"] if operator not in ("average", "std") + ] + if len(bad_operator_list) > 0: + raise ValueError( + f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" + ) recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"): - raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") + raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids @@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): assert self.sorting_analyzer.has_extension( "random_spikes" - ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" + ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index @@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator=percentile`" key = f"percentile_{percentile}" + if key not in self.data.keys(): + error_msg = ( + f"You have entered `operator={key}`, but the only operators calculated are " + f"{list(self.data.keys())}. Please use one of these as your `operator` in the " + f"`get_data` function." 
+ ) + raise ValueError(error_msg) + templates_array = self.data[key] if outputs == "numpy": @@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): probe=self.sorting_analyzer.get_probe(), ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("outputs must be `numpy` or `Templates`") def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): """ @@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save Parameters ---------- - unit_ids: list or None + unit_ids : list or None Unit ids to retrieve waveforms for - operator: "average" | "median" | "std" | "percentile", default: "average" + operator : "average" | "median" | "std" | "percentile", default: "average" The operator to compute the templates - percentile: float, default: None + percentile : float, default: None Percentile to use for operator="percentile" - save: bool, default True + save : bool, default: True In case, the operator is not computed yet it can be saved to folder or zarr - outputs: "numpy" | "Templates" + outputs : "numpy" | "Templates", default: "numpy" Whether to return a numpy array or a Templates object Returns ------- - templates: np.array + templates : np.array | Templates The returned templates (num_units, num_samples, num_channels) """ if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" key = f"pencentile_{percentile}" if key in self.data: @@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save is_scaled=self.sorting_analyzer.return_scaled, ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("`outputs` must be 'numpy' or 'Templates'") def get_unit_template(self, unit_id, operator="average"): """ @@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): ---------- unit_id: str | int Unit id to retrieve waveforms for - operator: str + operator: str, default: "average" The operator to compute the templates Returns @@ -691,22 +710,23 @@ class ComputeNoiseLevels(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) + def _set_params(self, **noise_level_params): + params = noise_level_params.copy() return params def _select_extension_data(self, unit_ids): - # this do not depend on units + # this does not depend on units return self.data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - # this do not depend on units + # this does not depend on units return self.data.copy() def _run(self, verbose=False): @@ -717,6 +737,15 @@ def _run(self, verbose=False): def _get_data(self): return self.data["noise_levels"] + def _handle_backward_compatibility_on_load(self): + # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None) + # now it is handle more explicitly using random_slices_kwargs=dict() + for key in ("num_chunks_per_segment", "chunk_size", 
"seed"): + if key in self.params: + if "random_slices_kwargs" not in self.params: + self.params["random_slices_kwargs"] = dict() + self.params["random_slices_kwargs"][key] = self.params.pop(key) + register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1fa218851b..2dc7e0e9bc 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -16,7 +16,7 @@ from .globals import get_global_tmp_folder, is_set_global_tmp_folder from .core_tools import ( - check_json, + is_path_remote, clean_zarr_folder_name, is_dict_extractor, SIJsonEncoder, @@ -673,7 +673,7 @@ def dump_to_json( ) -> None: """ Dump recording extractor to json file. - The extractor can be re-loaded with load_extractor(json_file) + The extractor can be re-loaded with load(json_file) Parameters ---------- @@ -715,7 +715,7 @@ def dump_to_pickle( ): """ Dump recording extractor to a pickle file. - The extractor can be re-loaded with load_extractor(pickle_file) + The extractor can be re-loaded with load(pickle_file) Parameters ---------- @@ -752,7 +752,9 @@ def dump_to_pickle( file_path.write_bytes(pickle.dumps(dump_dict)) @staticmethod - def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None) -> "BaseExtractor": + def load( + file_or_folder_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None + ) -> "BaseExtractor": """ Load extractor from file path (.json or .pkl) @@ -761,62 +763,10 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo * save (...) a folder which contain data + json (or pickle) + metadata. """ + # use loading.py and keep backward compatibility + from .loading import load - file_path = Path(file_path) - if base_folder is True: - base_folder = file_path.parent - - if file_path.is_file(): - # standard case based on a file (json or pickle) - if str(file_path).endswith(".json"): - with open(file_path, "r") as f: - d = json.load(f) - elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(file_path, "rb") as f: - d = pickle.load(f) - else: - raise ValueError(f"Impossible to load {file_path}") - if "warning" in d: - print("The extractor was not serializable to file") - return None - - extractor = BaseExtractor.from_dict(d, base_folder=base_folder) - return extractor - - elif file_path.is_dir(): - # case from a folder after a calling extractor.save(...) - folder = file_path - file = None - - if folder.suffix == ".zarr": - from .zarrextractors import read_zarr - - extractor = read_zarr(folder) - else: - # the is spikeinterface<=0.94.0 - # a folder came with 'cached.json' - for dump_ext in ("json", "pkl", "pickle"): - f = folder / f"cached.{dump_ext}" - if f.is_file(): - file = f - - # spikeinterface>=0.95.0 - f = folder / f"si_folder.json" - if f.is_file(): - file = f - - if file is None: - raise ValueError(f"This folder is not a cached folder {file_path}") - extractor = BaseExtractor.load(file, base_folder=folder) - - return extractor - - else: - error_msg = ( - f"{file_path} is not a file or a folder. 
It should point to either a json, pickle file or a " - "folder that is the result of extractor.save(...)" - ) - raise ValueError(error_msg) + return load(file_or_folder_path, base_folder=base_folder) def __reduce__(self): """ @@ -1167,50 +1117,6 @@ def _check_same_version(class_string, version): return "unknown" -def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: - """ - Instantiate extractor from: - * a dict - * a json file - * a pickle file - * folder (after save) - * a zarr folder (after save) - - Parameters - ---------- - file_or_folder_or_dict : dictionary or folder or file (json, pickle) - The file path, folder path, or dictionary to load the extractor from - base_folder : str | Path | bool (optional) - The base folder to make relative paths absolute. - If True and file_or_folder_or_dict is a file, the parent folder of the file is used. - - Returns - ------- - extractor: Recording or Sorting - The loaded extractor object - """ - if isinstance(file_or_folder_or_dict, dict): - assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict" - return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder) - else: - return BaseExtractor.load(file_or_folder_or_dict, base_folder=base_folder) - - -def load_extractor_from_dict(d, base_folder=None) -> BaseExtractor: - warnings.warn("Use load_extractor(..) instead") - return BaseExtractor.from_dict(d, base_folder=base_folder) - - -def load_extractor_from_json(json_file, base_folder=None) -> "BaseExtractor": - warnings.warn("Use load_extractor(..) instead") - return BaseExtractor.load(json_file, base_folder=base_folder) - - -def load_extractor_from_pickle(pkl_file, base_folder=None) -> "BaseExtractor": - warnings.warn("Use load_extractor(..) 
instead") - return BaseExtractor.load(pkl_file, base_folder=base_folder) - - class BaseSegment: def __init__(self): self._parent_extractor = None diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..3e7283090b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,4 +1,5 @@ from __future__ import annotations + import warnings from pathlib import Path @@ -7,14 +8,9 @@ from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets -from .core_tools import ( - convert_bytes_to_str, - convert_seconds_to_str, -) -from .recording_tools import write_binary_recording - - +from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs +from .recording_tools import write_binary_recording class BaseRecording(BaseRecordingSnippets): @@ -236,15 +232,9 @@ def get_duration(self, segment_index=None) -> float: float The duration in seconds """ - segment_index = self._check_segment_index(segment_index) - - if self.has_time_vector(segment_index): - times = self.get_times(segment_index) - segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency()) - else: - segment_num_samples = self.get_num_samples(segment_index=segment_index) - segment_duration = segment_num_samples / self.get_sampling_frequency() - + segment_duration = ( + self.get_end_time(segment_index) - self.get_start_time(segment_index) + (1 / self.get_sampling_frequency()) + ) return segment_duration def get_total_duration(self) -> float: @@ -256,7 +246,7 @@ def get_total_duration(self) -> float: float The duration in seconds """ - duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())]) + duration = sum([self.get_duration(segment_index) for segment_index in range(self.get_num_segments())]) return duration def get_memory_size(self, segment_index=None) -> int: @@ -449,6 +439,40 @@ def get_times(self, segment_index=None) -> np.ndarray: times = rs.get_times() return times + def get_start_time(self, segment_index=None) -> float: + """Get the start time of the recording segment. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The start time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self._recording_segments[segment_index] + return rs.get_start_time() + + def get_end_time(self, segment_index=None) -> float: + """Get the stop time of the recording segment. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The stop time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self._recording_segments[segment_index] + return rs.get_end_time() + def has_time_vector(self, segment_index=None): """Check if the segment of the recording has a time vector. @@ -509,6 +533,35 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. 
+ + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(segment_index=idx): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds @@ -878,6 +931,21 @@ def get_times(self) -> np.ndarray: time_vector += self.t_start return time_vector + def get_start_time(self) -> float: + if self.time_vector is not None: + return self.time_vector[0] + else: + return self.t_start if self.t_start is not None else 0.0 + + def get_end_time(self) -> float: + if self.time_vector is not None: + return self.time_vector[-1] + else: + t_stop = (self.get_num_samples() - 1) / self.sampling_frequency + if self.t_start is not None: + t_stop += self.t_start + return t_stop + def get_times_kwargs(self) -> dict: """ Retrieves the timing attributes characterizing a RecordingSegment @@ -921,11 +989,11 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = round(sample_index) + sample_index = np.round(sample_index).astype(int) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - return int(sample_index) + return sample_index def get_num_samples(self) -> int: """Returns the number of samples in this signal segment diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 310533c96b..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( - f"The given Probe have 'device_channel_indices' that do not match channel count \n" - f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n" + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2af48407a3..9a0e242d62 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -135,7 +135,7 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, - unit_id, + unit_id: str | int, segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0316b3bab1..aa69fe585b 100644 --- a/src/spikeinterface/core/generate.py +++ 
b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Literal +from typing import Literal, Optional from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -134,7 +134,7 @@ def generate_sorting( seed = _ensure_seed(seed) rng = np.random.default_rng(seed) num_segments = len(durations) - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] spikes = [] for segment_index in range(num_segments): @@ -1111,7 +1111,7 @@ def __init__( """ - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] super().__init__(sampling_frequency, unit_ids) self.num_units = num_units @@ -1138,6 +1138,7 @@ def __init__( firing_rates=firing_rates, refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, + unit_ids=unit_ids, t_start=None, ) self.add_sorting_segment(segment) @@ -1161,6 +1162,7 @@ def __init__( firing_rates: float | np.ndarray, refractory_period_seconds: float | np.ndarray, seed: int, + unit_ids: list[str], t_start: Optional[float] = None, ): self.num_units = num_units @@ -1177,7 +1179,8 @@ def __init__( self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) @@ -1280,7 +1283,7 @@ def __init__( noise_block_size: int = 30000, ): - channel_ids = np.arange(num_channels) + channel_ids = [str(idx) for idx in np.arange(num_channels)] dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 23d60a5ac5..e9974adff7 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,8 +97,12 @@ def is_set_global_dataset_folder() -> bool: ######################################## +_default_job_kwargs = dict( + pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 +) + global global_job_kwargs -global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) +global_job_kwargs = _default_job_kwargs.copy() global global_job_kwargs_set global_job_kwargs_set = False @@ -135,7 +139,7 @@ def reset_global_job_kwargs(): Reset the global job kwargs. 
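Note that the generators touched above now produce string ids ("0", "1", ...) instead of integer ids, so downstream code should select units and channels by string. A small sketch:

    from spikeinterface.core import generate_sorting

    sorting = generate_sorting(num_units=3, durations=[5.0], sampling_frequency=30_000.0)
    print(sorting.unit_ids)            # e.g. ['0', '1', '2']

    # unit_id is now typically a string
    spike_train = sorting.get_unit_spike_train(unit_id="0", segment_index=0)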
""" global global_job_kwargs - global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True) + global_job_kwargs = _default_job_kwargs.copy() def is_set_global_job_kwargs_set() -> bool: diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 5240edcee7..38a08c0fab 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -12,8 +12,9 @@ import sys from tqdm.auto import tqdm -from concurrent.futures import ProcessPoolExecutor -import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +import multiprocessing +import threading from threadpoolctl import threadpool_limits @@ -39,6 +40,7 @@ job_keys = ( + "pool_engine", "n_jobs", "total_memory", "chunk_size", @@ -46,7 +48,7 @@ "chunk_duration", "progress_bar", "mp_context", - "max_threads_per_process", + "max_threads_per_worker", ) # theses key are the same and should not be in th final dict @@ -58,11 +60,66 @@ ) +def get_best_job_kwargs(): + """ + Gives best possible job_kwargs for the platform. + Currently this function is from developer experience, but may be adapted in the future. + """ + + n_cpu = os.cpu_count() + + if platform.system() == "Linux": + pool_engine = "process" + mp_context = "fork" + + elif platform.system() == "Darwin": + pool_engine = "process" + mp_context = "spawn" + + else: # windows + # on windows and macos the fork is forbidden and process+spwan is super slow at startup + # so let's go to threads + pool_engine = "thread" + mp_context = None + n_jobs = n_cpu + max_threads_per_worker = 1 + + if platform.system() in ("Linux", "Darwin"): + # here we try to balance between the number of workers (n_jobs) and the number of sub thread + # this is totally empirical but this is a good start + if n_cpu <= 16: + # for small n_cpu let's make many process + n_jobs = n_cpu + max_threads_per_worker = 1 + else: + # let's have fewer processes with more threads each + n_jobs = int(n_cpu / 4) + max_threads_per_worker = 8 + + return dict( + pool_engine=pool_engine, + mp_context=mp_context, + n_jobs=n_jobs, + max_threads_per_worker=max_threads_per_worker, + ) + + def fix_job_kwargs(runtime_job_kwargs): from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set job_kwargs = get_global_job_kwargs() + # deprecation with backward compatibility + # this can be removed in 0.104.0 + if "max_threads_per_process" in runtime_job_kwargs: + runtime_job_kwargs = runtime_job_kwargs.copy() + runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process") + warnings.warn( + "job_kwargs: max_threads_per_process was changed to max_threads_per_worker, max_threads_per_process will be removed in 0.104", + DeprecationWarning, + stacklevel=2, + ) + for k in runtime_job_kwargs: assert k in job_keys, ( f"{k} is not a valid job keyword argument. " f"Available keyword arguments are: {list(job_keys)}" @@ -99,14 +156,14 @@ def fix_job_kwargs(runtime_job_kwargs): n_jobs = max(n_jobs, 1) job_kwargs["n_jobs"] = min(n_jobs, os.cpu_count()) - if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set(): - warnings.warn( - "`n_jobs` is not set so parallel processing is disabled! " - "To speed up computations, it is recommended to set n_jobs either " - "globally (with the `spikeinterface.set_global_job_kwargs()` function) or " - "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` " - "for more information about job_kwargs." 
- ) + # if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set(): + # warnings.warn( + # "`n_jobs` is not set so parallel processing is disabled! " + # "To speed up computations, it is recommended to set n_jobs either " + # "globally (with the `spikeinterface.set_global_job_kwargs()` function) or " + # "locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` " + # "for more information about job_kwargs." + # ) return job_kwargs @@ -149,12 +206,12 @@ def divide_segment_into_chunks(num_frames, chunk_size): def divide_recording_into_chunks(recording, chunk_size): - all_chunks = [] + recording_slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return all_chunks + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return recording_slices def ensure_n_jobs(recording, n_jobs=1): @@ -185,6 +242,22 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs ): @@ -231,18 +304,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment @@ -282,11 +344,15 @@ class ChunkRecordingExecutor: If True, output is verbose job_name : str, default: "" Job name + progress_bar : bool, default: False + If True, a progress bar is printed to monitor the progress of the process handle_returns : bool, default: False If True, the function can return values gather_func : None or callable, default: None Optional function that is called in the main thread and retrieves the results of each worker. This function can be used instead of `handle_returns` to implement custom storage on-the-fly. 
+ pool_engine : "process" | "thread", default: "thread" + If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor n_jobs : int, default: 1 Number of jobs to be used. Use -1 to use as many jobs as number of cores total_memory : str, default: None @@ -300,13 +366,12 @@ class ChunkRecordingExecutor: mp_context : "fork" | "spawn" | None, default: None "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). "fork" is only safely available on LINUX systems. - max_threads_per_process : int or None, default: None + max_threads_per_worker : int or None, default: None Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. - progress_bar : bool, default: False - If True, a progress bar is printed to monitor the progress of the process - + need_worker_index : bool, default False + If True then each worker will also have a "worker_index" injected in the local worker dict. Returns ------- @@ -324,6 +389,7 @@ def __init__( progress_bar=False, handle_returns=False, gather_func=None, + pool_engine="thread", n_jobs=1, total_memory=None, chunk_size=None, @@ -331,19 +397,21 @@ def __init__( chunk_duration=None, mp_context=None, job_name="", - max_threads_per_process=1, + max_threads_per_worker=1, + need_worker_index=False, ): self.recording = recording self.func = func self.init_func = init_func self.init_args = init_args - if mp_context is None: - mp_context = recording.get_preferred_mp_context() - if mp_context is not None and platform.system() == "Windows": - assert mp_context != "fork", "'fork' mp_context not supported on Windows!" - elif mp_context == "fork" and platform.system() == "Darwin": - warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + if pool_engine == "process": + if mp_context is None: + mp_context = recording.get_preferred_mp_context() + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') self.mp_context = mp_context @@ -363,7 +431,11 @@ def __init__( n_jobs=self.n_jobs, ) self.job_name = job_name - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker + + self.pool_engine = pool_engine + + self.need_worker_index = need_worker_index if verbose: chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize @@ -375,6 +447,7 @@ def __init__( print( self.job_name, "\n" + f"engine={self.pool_engine} - " f"n_jobs={self.n_jobs} - " f"samples_per_chunk={self.chunk_size:,} - " f"chunk_memory={chunk_memory_str} - " @@ -382,11 +455,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, recording_slices=None): """ Runs the defined jobs. 
""" - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if recording_slices is None: + recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] @@ -395,69 +470,197 @@ def run(self): if self.n_jobs == 1: if self.progress_bar: - all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name) + recording_slices = tqdm( + recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices) + ) - worker_ctx = self.init_func(*self.init_args) - for segment_index, frame_start, frame_stop in all_chunks: - res = self.func(segment_index, frame_start, frame_stop, worker_ctx) + worker_dict = self.init_func(*self.init_args) + if self.need_worker_index: + worker_dict["worker_index"] = 0 + + for segment_index, frame_start, frame_stop in recording_slices: + res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) - else: - n_jobs = min(self.n_jobs, len(all_chunks)) - # parallel - with ProcessPoolExecutor( - max_workers=n_jobs, - initializer=worker_initializer, - mp_context=mp.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), - ) as executor: - results = executor.map(function_wrapper, all_chunks) + else: + n_jobs = min(self.n_jobs, len(recording_slices)) + + if self.pool_engine == "process": + + if self.need_worker_index: + lock = multiprocessing.Lock() + array_pid = multiprocessing.Array("i", n_jobs) + for i in range(n_jobs): + array_pid[i] = -1 + else: + lock = None + array_pid = None + + # parallel + with ProcessPoolExecutor( + max_workers=n_jobs, + initializer=process_worker_initializer, + mp_context=multiprocessing.get_context(self.mp_context), + initargs=( + self.func, + self.init_func, + self.init_args, + self.max_threads_per_worker, + self.need_worker_index, + lock, + array_pid, + ), + ) as executor: + results = executor.map(process_function_wrapper, recording_slices) + + if self.progress_bar: + results = tqdm( + results, desc=f"{self.job_name} (workers: {n_jobs} processes)", total=len(recording_slices) + ) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + + elif self.pool_engine == "thread": + # this is need to create a per worker local dict where the initializer will push the func wrapper + thread_local_data = threading.local() + + global _thread_started + _thread_started = 0 if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(all_chunks)) + # here the tqdm threading do not work (maybe collision) so we need to create a pbar + # before thread spawning + pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(recording_slices)) + + if self.need_worker_index: + lock = threading.Lock() + else: + lock = None + + with ThreadPoolExecutor( + max_workers=n_jobs, + initializer=thread_worker_initializer, + initargs=( + self.func, + self.init_func, + self.init_args, + self.max_threads_per_worker, + thread_local_data, + self.need_worker_index, + lock, + ), + ) as executor: + + recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices] + results = executor.map(thread_function_wrapper, recording_slices2) + + for res in results: + if self.progress_bar: + pbar.update(1) + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + if 
self.progress_bar: + pbar.close() + del pbar - for res in results: - if self.handle_returns: - returns.append(res) - if self.gather_func is not None: - self.gather_func(res) + else: + raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") return returns +class WorkerFuncWrapper: + """ + small wrapper that handles: + * local worker_dict + * max_threads_per_worker + """ + + def __init__(self, func, worker_dict, max_threads_per_worker): + self.func = func + self.worker_dict = worker_dict + self.max_threads_per_worker = max_threads_per_worker + + def __call__(self, args): + segment_index, start_frame, end_frame = args + if self.max_threads_per_worker is None: + return self.func(segment_index, start_frame, end_frame, self.worker_dict) + else: + with threadpool_limits(limits=self.max_threads_per_worker): + return self.func(segment_index, start_frame, end_frame, self.worker_dict) + + # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool -# the tricks is : theses 2 variables are global per worker -# so they are not share in the same process -global _worker_ctx -global _func +# the trick is : this variable is global per worker (so not shared in the same process) +global _process_func_wrapper -def worker_initializer(func, init_func, init_args, max_threads_per_process): - global _worker_ctx - if max_threads_per_process is None: - _worker_ctx = init_func(*init_args) +def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid): + global _process_func_wrapper + if max_threads_per_worker is None: + worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): - _worker_ctx = init_func(*init_args) - _worker_ctx["max_threads_per_process"] = max_threads_per_process - global _func - _func = func - - -def function_wrapper(args): - segment_index, start_frame, end_frame = args - global _func - global _worker_ctx - max_threads_per_process = _worker_ctx["max_threads_per_process"] - if max_threads_per_process is None: - return _func(segment_index, start_frame, end_frame, _worker_ctx) + with threadpool_limits(limits=max_threads_per_worker): + worker_dict = init_func(*init_args) + + if need_worker_index: + child_process = multiprocessing.current_process() + lock.acquire() + worker_index = None + for i in range(len(array_pid)): + if array_pid[i] == -1: + worker_index = i + array_pid[i] = child_process.ident + break + worker_dict["worker_index"] = worker_index + lock.release() + + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + + +def process_function_wrapper(args): + global _process_func_wrapper + return _process_func_wrapper(args) + + +# use by thread at init +global _thread_started + + +def thread_worker_initializer( + func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock +): + if max_threads_per_worker is None: + worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): - return _func(segment_index, start_frame, end_frame, _worker_ctx) + with threadpool_limits(limits=max_threads_per_worker): + worker_dict = init_func(*init_args) + + if need_worker_index: + lock.acquire() + global _thread_started + worker_index = _thread_started + _thread_started += 1 + worker_dict["worker_index"] = worker_index + lock.release() + + thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + + +def 
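Putting the pieces above together, a hedged sketch of driving ChunkRecordingExecutor with the new thread engine and per-worker index; func and init_func here are illustrative user code, not part of the library:

    import numpy as np
    from spikeinterface.core import generate_recording
    from spikeinterface.core.job_tools import ChunkRecordingExecutor

    recording = generate_recording(num_channels=4, durations=[2.0])

    def init_func(recording):
        # builds the per-worker dict that every chunk call receives
        return dict(recording=recording)

    def func(segment_index, start_frame, end_frame, worker_dict):
        traces = worker_dict["recording"].get_traces(
            segment_index=segment_index, start_frame=start_frame, end_frame=end_frame
        )
        # "worker_index" is injected because need_worker_index=True below
        return worker_dict["worker_index"], float(np.abs(traces).max())

    executor = ChunkRecordingExecutor(
        recording, func, init_func, init_args=(recording,),
        pool_engine="thread", n_jobs=2, chunk_duration="500ms",
        max_threads_per_worker=1, need_worker_index=True,
        handle_returns=True, progress_bar=False, job_name="toy job",
    )
    results = executor.run()   # list of (worker_index, chunk_max) tuples, one per chunk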
thread_function_wrapper(args): + thread_local_data = args[0] + args = args[1:] + return thread_local_data.func_wrapper(args) # Here some utils copy/paste from DART (Charlie Windolf) diff --git a/src/spikeinterface/core/loading.py b/src/spikeinterface/core/loading.py new file mode 100644 index 0000000000..c8dc44160e --- /dev/null +++ b/src/spikeinterface/core/loading.py @@ -0,0 +1,127 @@ +import warnings +from pathlib import Path + + +from .base import BaseExtractor +from .core_tools import is_path_remote + + +def load(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: + """ + General load function to load a SpikeInterface object. + + The function can load: + - a `Recording` or `Sorting` object from: + * dictionary + * json file + * pkl file + * binary folder (after `extractor.save(..., format='binary_folder')`) + * zarr folder (after `extractor.save(..., format='zarr')`) + * remote zarr folder + - (TODO) a `SortingAnalyzer` object from : + * binary folder + * zarr folder + * remote zarr folder + * WaveformExtractor folder + + Parameters + ---------- + file_or_folder_or_dict : dictionary or folder or file (json, pickle) + The file path, folder path, or dictionary to load the extractor from + base_folder : str | Path | bool (optional) + The base folder to make relative paths absolute. + If True and file_or_folder_or_dict is a file, the parent folder of the file is used. + + Returns + ------- + extractor: Recording or Sorting + The loaded extractor object + """ + if isinstance(file_or_folder_or_dict, dict): + assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict" + return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder) + else: + file_path = file_or_folder_or_dict + error_msg = ( + f"{file_path} is not a file or a folder. 
It should point to either a json, pickle file or a " + "folder that is the result of extractor.save(...)" + ) + if not is_path_remote(file_path): + file_path = Path(file_path) + + if base_folder is True: + base_folder = file_path.parent + + if file_path.is_file(): + # standard case based on a file (json or pickle) + if str(file_path).endswith(".json"): + import json + + with open(file_path, "r") as f: + d = json.load(f) + elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): + import pickle + + with open(file_path, "rb") as f: + d = pickle.load(f) + else: + raise ValueError(error_msg) + + # this is for back-compatibility since now unserializable objects will not + # be saved to file + if "warning" in d: + print("The extractor was not serializable to file") + return None + + extractor = BaseExtractor.from_dict(d, base_folder=base_folder) + + elif file_path.is_dir(): + # this can be and extractor, SortingAnalyzer, or WaveformExtractor + folder = file_path + file = None + + if folder.suffix == ".zarr": + from .zarrextractors import read_zarr + + extractor = read_zarr(folder) + else: + # For backward compatibility (v<=0.94) we check for the cached.json/pkl/pickle files + # In later versions (v>0.94) we use the si_folder.json file + for dump_ext in ("json", "pkl", "pickle"): + f = folder / f"cached.{dump_ext}" + if f.is_file(): + file = f + + f = folder / f"si_folder.json" + if f.is_file(): + file = f + + if file is None: + raise ValueError(error_msg) + extractor = BaseExtractor.load(file, base_folder=folder) + + else: + raise ValueError(error_msg) + else: + # remote case - zarr + if str(file_path).endswith(".zarr") or str(file_path).endswith(".zarr/"): + from .zarrextractors import read_zarr + + extractor = read_zarr(file_path) + else: + raise NotImplementedError( + "Only zarr format is supported for remote files and you should provide a path to a .zarr " + "remote path. You can save to a valid zarr folder using: " + "`extractor.save(folder='path/to/folder', format='zarr')`" + ) + + return extractor + + +def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: + warnings.warn( + "load_extractor() is deprecated and will be removed in the future. Please use load() instead.", + DeprecationWarning, + stacklevel=2, + ) + return load(file_or_folder_or_dict, base_folder=base_folder) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ceff8577d3..53c2445c77 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -1,22 +1,6 @@ """ -Pipeline on spikes/peaks/detected peaks - -Functions that can be chained: - * after peak detection - * already detected peaks - * spikes (labeled peaks) -to compute some additional features on-the-fly: - * peak localization - * peak-to-peak - * pca - * amplitude - * amplitude scaling - * ... 
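A small sketch of the new top-level load() entry point introduced in loading.py; the folder name is just a placeholder:

    from spikeinterface.core import generate_recording, load, load_extractor

    recording = generate_recording(num_channels=4, durations=[2.0])
    recording.save(folder="my_recording_folder", format="binary_folder")

    loaded = load("my_recording_folder")    # folder written by extractor.save(...)
    # remote zarr stores work too, e.g. load("s3://bucket/recording.zarr")

    # the old name still works but now emits a DeprecationWarning
    loaded_again = load_extractor("my_recording_folder")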
- -There are two ways for using theses "plugin nodes": - * during `peak_detect()` - * when peaks are already detected and reduced with `select_peaks()` - * on a sorting object + + """ from __future__ import annotations @@ -96,16 +80,26 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar class PeakSource(PipelineNode): - # base class for peak detector + def get_trace_margin(self): raise NotImplementedError def get_dtype(self): return base_peak_dtype + def get_peak_slice( + self, + segment_index, + start_frame, + end_frame, + ): + # not needed for PeakDetector + raise NotImplementedError + # this is used in sorting components class PeakDetector(PeakSource): + # base class for peak detector or template matching pass @@ -127,11 +121,18 @@ def get_trace_margin(self): def get_dtype(self): return base_peak_dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -212,8 +213,7 @@ def get_trace_margin(self): def get_dtype(self): return self._dtype - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - # get local peaks + def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] if self.include_spikes_in_margin: @@ -222,6 +222,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): ) else: i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + return i0, i1 + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + # if self.include_spikes_in_margin: + # i0, i1 = np.searchsorted( + # peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin] + # ) + # else: + # i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) + i0, i1 = peak_slice + local_peaks = peaks_in_segment[i0:i1] # make sample index local to traces @@ -467,16 +481,76 @@ def run_node_pipeline( nodes, job_kwargs, job_name="pipeline", - mp_context=None, + # mp_context=None, gather_mode="memory", gather_kwargs={}, squeeze_output=True, folder=None, names=None, verbose=False, + skip_after_n_peaks=None, + recording_slices=None, ): """ - Common function to run pipeline with peak detector or already detected peak. + Machinery to compute in parallel operations on peaks and traces. + + This useful in several use cases: + * in sortingcomponents : detect peaks and make some computation on then (localize, pca, ...) + * in sortingcomponents : replay some peaks and make some computation on then (localize, pca, ...) + * postprocessing : replay some spikes and make some computation on then (localize, pca, ...) + + Here a "peak" is a spike without any labels just a "detected". 
+ Here a "spike" is a spike with any a label so already sorted. + + The main idea is to have a graph of nodes. + Every node is doing a computaion of some peaks and related traces. + The first node is PeakSource so either a peak detector PeakDetector or peak/spike replay (PeakRetriever/SpikeRetriever) + + Every node can have one or several output that can be directed to other nodes (aka nodes have parents). + + Every node can optionally have a global output that will be gathered by the main process. + This is controlled by return_output = True. + + The gather consists of concatenating features related to peaks (localization, pca, scaling, ...) into a single big vector. + These vectors can be in "memory" or in files ("npy") + + + Parameters + ---------- + + recording: Recording + + nodes: a list of PipelineNode + + job_kwargs: dict + The classical job_kwargs + job_name : str + The name of the pipeline used for the progress_bar + gather_mode : "memory" | "npz" + + gather_kwargs : dict + OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. + squeeze_output : bool, default True + If only one output node then squeeze the tuple + folder : str | Path | None + Used for gather_mode="npz" + names : list of str + Names of outputs. + verbose : bool, default False + Verbosity. + skip_after_n_peaks : None | int + Skip the computation after n_peaks. + This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording. + + Returns + ------- + outputs: tuple of np.array | np.array + a tuple of vector for the output of nodes having return_output=True. + If squeeze_output=True and only one output then directly np.array. 
""" check_graph(nodes) @@ -484,6 +558,11 @@ def run_node_pipeline( job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) + if skip_after_n_peaks is not None: + skip_after_n_peaks_per_worker = skip_after_n_peaks / job_kwargs["n_jobs"] + else: + skip_after_n_peaks_per_worker = None + if gather_mode == "memory": gather_func = GatherToMemory() elif gather_mode == "npy": @@ -491,7 +570,7 @@ def run_node_pipeline( else: raise ValueError(f"wrong gather_mode : {gather_mode}") - init_args = (recording, nodes) + init_args = (recording, nodes, skip_after_n_peaks_per_worker) processor = ChunkRecordingExecutor( recording, @@ -504,18 +583,20 @@ def run_node_pipeline( **job_kwargs, ) - processor.run() + processor.run(recording_slices=recording_slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs -def _init_peak_pipeline(recording, nodes): +def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["nodes"] = nodes worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker + worker_ctx["num_peaks"] = 0 return worker_ctx @@ -523,66 +604,88 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c recording = worker_ctx["recording"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] + skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] recording_segment = recording._recording_segments[segment_index] - traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True - ) + node0 = nodes[0] - # compute the graph - pipeline_outputs = {} - for node in nodes: - node_parents = node.parents if node.parents else list() - node_input_args = tuple() - for parent in node_parents: - parent_output = pipeline_outputs[parent] - parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) - node_input_args += parent_outputs_tuple - if isinstance(node, PeakDetector): - # to handle compatibility peak detector is a special case - # with specific margin - # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() - if extra_margin: - trace_detection = traces_chunk[extra_margin:-extra_margin] + if isinstance(node0, (SpikeRetriever, PeakRetriever)): + # in this case PeakSource could have no peaks and so no need to load traces just skip + peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin) + load_trace_and_compute = i0 < i1 + else: + # PeakDetector always need traces + load_trace_and_compute = True + + if skip_after_n_peaks_per_worker is not None: + if worker_ctx["num_peaks"] > skip_after_n_peaks_per_worker: + load_trace_and_compute = False + + if load_trace_and_compute: + traces_chunk, left_margin, right_margin = get_chunk_with_margin( + recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True + ) + # compute the graph + pipeline_outputs = {} + for node in nodes: + node_parents = node.parents if node.parents else list() + node_input_args = tuple() + for parent in node_parents: + parent_output = pipeline_outputs[parent] + parent_outputs_tuple = parent_output if isinstance(parent_output, tuple) else (parent_output,) + node_input_args += parent_outputs_tuple + if isinstance(node, 
PeakDetector): + # to handle compatibility peak detector is a special case + # with specific margin + # TODO later when in master: change this later + extra_margin = max_margin - node.get_trace_margin() + if extra_margin: + trace_detection = traces_chunk[extra_margin:-extra_margin] + else: + trace_detection = traces_chunk + node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) + # set sample index to local + node_output[0]["sample_index"] += extra_margin + elif isinstance(node, PeakSource): + node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice) else: - trace_detection = traces_chunk - node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) - # set sample index to local - node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakSource): - node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) - else: - # TODO later when in master: change the signature of all nodes (or maybe not!) - node_output = node.compute(traces_chunk, *node_input_args) - pipeline_outputs[node] = node_output - - # propagate the output - pipeline_outputs_tuple = tuple() - for node in nodes: - # handle which buffer are given to the output - # this is controlled by node.return_output being a bool or tuple of bool - out = pipeline_outputs[node] - if isinstance(out, tuple): - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += out - elif isinstance(node.return_output, tuple): - for flag, e in zip(node.return_output, out): - if flag: - pipeline_outputs_tuple += (e,) - else: - if isinstance(node.return_output, bool) and node.return_output: - pipeline_outputs_tuple += (out,) - elif isinstance(node.return_output, tuple): - # this should not apppend : maybe a checker somewhere before ? - pass + # TODO later when in master: change the signature of all nodes (or maybe not!) + node_output = node.compute(traces_chunk, *node_input_args) + pipeline_outputs[node] = node_output + + if skip_after_n_peaks_per_worker is not None and isinstance(node, PeakSource): + worker_ctx["num_peaks"] += node_output[0].size + + # propagate the output + pipeline_outputs_tuple = tuple() + for node in nodes: + # handle which buffer are given to the output + # this is controlled by node.return_output being a bool or tuple of bool + out = pipeline_outputs[node] + if isinstance(out, tuple): + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += out + elif isinstance(node.return_output, tuple): + for flag, e in zip(node.return_output, out): + if flag: + pipeline_outputs_tuple += (e,) + else: + if isinstance(node.return_output, bool) and node.return_output: + pipeline_outputs_tuple += (out,) + elif isinstance(node.return_output, tuple): + # this should not apppend : maybe a checker somewhere before ? 
+ pass - if isinstance(nodes[0], PeakDetector): - # the first out element is the peak vector - # we need to go back to absolut sample index - pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin + if isinstance(nodes[0], PeakDetector): + # the first out element is the peak vector + # we need to go back to absolut sample index + pipeline_outputs_tuple[0]["sample_index"] += start_frame - left_margin - return pipeline_outputs_tuple + return pipeline_outputs_tuple + + else: + # the gather will skip this output and not concatenate it + return class GatherToMemory: @@ -595,6 +698,9 @@ def __init__(self): self.tuple_mode = None def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) @@ -655,6 +761,9 @@ def __init__(self, folder, names, npy_header_size=1024, exist_ok=False): self.final_shapes.append(None) def __call__(self, res): + if res is None: + return + if self.tuple_mode is None: # first loop only self.tuple_mode = isinstance(res, tuple) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 77d427bc88..284b1141ae 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,8 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -245,9 +247,9 @@ def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsign # create a local dict per worker worker_ctx = {} if isinstance(recording, dict): - from spikeinterface.core import load_extractor + from spikeinterface.core import load - worker_ctx["recording"] = load_extractor(recording) + worker_ctx["recording"] = load(recording) else: worker_ctx["recording"] = recording @@ -512,33 +514,38 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned -def get_random_data_chunks( +def get_random_recording_slices( recording, - return_scaled=False, + method="full_random", num_chunks_per_segment=20, - chunk_size=10000, - concatenated=True, - seed=0, + chunk_duration="500ms", + chunk_size=None, margin_frames=0, + seed=None, ): """ - Extract random chunks across segments + Get random slice of a recording across segments. - This is used for instance in get_noise_levels() to estimate noise on traces. + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. Parameters ---------- recording : BaseRecording The recording to get random chunks from - return_scaled : bool, default: False - If True, returned chunks are scaled to uV + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. 
concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -547,42 +554,89 @@ def get_random_data_chunks( ------- chunk_list : np.array Array of concatenate chunks per segment + + """ # TODO: if segment have differents length make another sampling that dependant on the length of the segment # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY # And randomize the number of chunk per segment weighted by segment duration - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") - rng = np.random.default_rng(seed) - chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] + return recording_slices - chunk_list.extend(segment_trace_chunk) + +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. 
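A sketch of the new get_random_recording_slices() helper documented above; it returns (segment_index, frame_start, frame_stop) tuples that can be reused by other chunk-based machinery:

    from spikeinterface.core import generate_recording
    from spikeinterface.core.recording_tools import get_random_recording_slices

    recording = generate_recording(num_channels=4, durations=[10.0])

    recording_slices = get_random_recording_slices(
        recording, method="full_random", num_chunks_per_segment=5,
        chunk_duration="500ms", margin_frames=0, seed=2205,
    )
    print(recording_slices[:2])   # e.g. [(0, 12000, 27000), (0, 31000, 46000)]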
+ + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + return_scaled : bool, default: False + If True, returned chunks are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + chunk_list = [] + for segment_index, start_frame, end_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) @@ -637,19 +691,52 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + random_slices_kwargs: dict = {}, + **kwargs, ) -> np.ndarray: """ Estimate noise for each channel using MAD methods. You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. + + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- @@ -662,8 +749,11 @@ def get_noise_levels( The method to use to estimate noise levels force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor - random_chunk_kwargs : dict - Kwargs for get_random_data_chunks + random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. 
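A sketch of the rewritten get_random_data_chunks(): slice selection is forwarded to get_random_recording_slices() through **random_slices_kwargs, and only the trace retrieval stays here:

    from spikeinterface.core import generate_recording, get_random_data_chunks

    recording = generate_recording(num_channels=4, durations=[10.0])

    chunks = get_random_data_chunks(
        recording, return_scaled=False, concatenated=True,
        num_chunks_per_segment=10, chunk_duration="250ms", seed=0,
    )
    print(chunks.shape)   # (num_chunks_per_segment * chunk_size, num_channels)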
+ + {} Returns ------- @@ -679,19 +769,56 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # This is to keep backward compatibility + # lets keep for a while and remove this maybe in 0.103.0 + # chunk_size used to be in the signature and now is ambiguous + random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) + if len(random_slices_kwargs_) > 0 or "chunk_size" in job_kwargs: + msg = ( + "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" + "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" + "Please read get_random_recording_slices() documentation for more options." + ) + # if the user use both the old and the new behavior then an error is raised + assert len(random_slices_kwargs) == 0, msg + warnings.warn(msg) + random_slices_kwargs = random_slices_kwargs_ + if "chunk_size" in job_kwargs: + random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] + + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + noise_levels_chunks = [] + + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, + ) + executor.run(recording_slices=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) + + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 5f33350820..213968a80b 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -197,17 +197,23 @@ def random_spikes_selection( cum_sizes = np.cumsum([0] + [s.size for s in spikes]) # this fast when numba - spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids) + spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False) random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] for segment_index in range(sorting.get_num_segments()): - inds_in_seg = spike_indices[segment_index][unit_id] + cum_sizes[segment_index] + # this is local index + inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: - inds_in_seg = inds_in_seg[inds_in_seg >= margin_size] - inds_in_seg = inds_in_seg[inds_in_seg < (num_samples[segment_index] - margin_size)] - all_unit_indices.append(inds_in_seg) + local_spikes = spikes[segment_index][inds_in_seg] + mask = (local_spikes["sample_index"] >= margin_size) & ( + local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) + ) + 
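A sketch of the new get_noise_levels() call style: slice selection goes through random_slices_kwargs, while the per-chunk MAD/STD computation now runs through ChunkRecordingExecutor and therefore accepts job kwargs:

    from spikeinterface.core import generate_recording, get_noise_levels

    recording = generate_recording(num_channels=4, durations=[10.0])

    noise_levels = get_noise_levels(
        recording,
        return_scaled=False,
        method="mad",
        random_slices_kwargs=dict(num_chunks_per_segment=10, chunk_duration="500ms", seed=0),
        n_jobs=1,    # job kwargs are forwarded to the chunk executor
    )
    print(noise_levels.shape)   # (num_channels,)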
inds_in_seg = inds_in_seg[mask] + # go back to absolut index + inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] + all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) selected_unit_indices = rng.choice( all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 4961db8524..81f0a1cc56 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from copy import copy from packaging.version import parse from time import perf_counter @@ -23,7 +24,7 @@ from .baserecording import BaseRecording from .basesorting import BaseSorting -from .base import load_extractor +from .loading import load from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match from .core_tools import check_json, retrieve_importing_provenance, is_path_remote, clean_zarr_folder_name from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging @@ -45,6 +46,7 @@ def create_sorting_analyzer( sparsity=None, return_scaled=True, overwrite=False, + backend_options=None, **sparsity_kwargs, ) -> "SortingAnalyzer": """ @@ -63,24 +65,29 @@ def create_sorting_analyzer( recording : Recording The recording object folder : str or Path or None, default: None - The folder where waveforms are cached + The folder where analyzer is cached format : "memory | "binary_folder" | "zarr", default: "memory" - The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder. + The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder. The "folder" argument must be specified in case of mode "folder". - If "memory" is used, the waveforms are stored in RAM. Use this option carefully! + If "memory" is used, the analyzer is stored in RAM. Use this option carefully! sparse : bool, default: True If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. Then, the sparsity will be propagated to all ResultExtention that handle sparsity (like wavforms, pca, ...) You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs) sparsity : ChannelSparsity or None, default: None - The sparsity used to compute waveforms. If this is given, `sparse` is ignored. + The sparsity used to compute exensions. If this is given, `sparse` is ignored. return_scaled : bool, default: True All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates". This prevent return_scaled being differents from different extensions and having wrong snr for instance. overwrite: bool, default: False If True, overwrite the folder if it already exists. - + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. 
compression/filters for zarr) + sparsity_kwargs : keyword arguments Returns ------- @@ -91,7 +98,7 @@ def create_sorting_analyzer( -------- >>> import spikeinterface as si - >>> # Extract dense waveforms and save to disk with binary_folder format. + >>> # Create dense analyzer and save to disk with binary_folder format. >>> sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder="/path/to_my/result") >>> # Can be reload @@ -117,12 +124,14 @@ def create_sorting_analyzer( """ if format != "memory": if format == "zarr": - folder = clean_zarr_folder_name(folder) - if Path(folder).is_dir(): - if not overwrite: - raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") - else: - shutil.rmtree(folder) + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + if not is_path_remote(folder): + if Path(folder).is_dir(): + if not overwrite: + raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") + else: + shutil.rmtree(folder) # handle sparsity if sparsity is not None: @@ -144,27 +153,38 @@ def create_sorting_analyzer( return_scaled = False sorting_analyzer = SortingAnalyzer.create( - sorting, recording, format=format, folder=folder, sparsity=sparsity, return_scaled=return_scaled + sorting, + recording, + format=format, + folder=folder, + sparsity=sparsity, + return_scaled=return_scaled, + backend_options=backend_options, ) return sorting_analyzer -def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_options=None) -> "SortingAnalyzer": +def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer": """ Load a SortingAnalyzer object from disk. Parameters ---------- folder : str or Path - The folder / zarr folder where the waveform extractor is stored + The folder / zarr folder where the analyzer is stored. If the folder is a remote path stored in the cloud, + the backend_options can be used to specify credentials. If the remote path is not accessible, + and backend_options is not provided, the function will try to load the object in anonymous mode (anon=True), + which enables to load data from open buckets. load_extensions : bool, default: True Load all extensions or not. format : "auto" | "binary_folder" | "zarr" The format of the folder. - storage_options : dict | None, default: None - The storage options to specify credentials to remote zarr bucket. - For open buckets, it doesn't need to be specified. + backend_options : dict | None, default: None + The backend options for the backend. 
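A hedged sketch of passing backend_options when creating a zarr-backed analyzer. The compressor choice is purely illustrative (any zarr-compatible saving option could go in saving_options), and generate_ground_truth_recording is only used to obtain a matching recording/sorting pair:

    import numcodecs
    from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)

    backend_options = dict(
        saving_options=dict(compressor=numcodecs.Blosc(cname="zstd", clevel=5)),
        # storage_options=dict(anon=True),   # fsspec options, only relevant for remote stores
    )

    analyzer = create_sorting_analyzer(
        sorting, recording,
        format="zarr", folder="my_analyzer.zarr",
        backend_options=backend_options,
    )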
+ The dictionary can contain the following keys: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets) Returns ------- @@ -172,7 +192,20 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_o The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, storage_options=storage_options) + if is_path_remote(folder) and backend_options is None: + try: + return SortingAnalyzer.load( + folder, load_extensions=load_extensions, format=format, backend_options=backend_options + ) + except Exception as e: + backend_options = dict(storage_options=dict(anon=True)) + return SortingAnalyzer.load( + folder, load_extensions=load_extensions, format=format, backend_options=backend_options + ) + else: + return SortingAnalyzer.load( + folder, load_extensions=load_extensions, format=format, backend_options=backend_options + ) class SortingAnalyzer: @@ -205,7 +238,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, - storage_options=None, + backend_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -215,10 +248,18 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled - self.storage_options = storage_options + # this is used to store temporary recording self._temporary_recording = None + # backend-specific kwargs for different formats, which can be used to + # set some parameters for saving (e.g., compression) + # + # - storage_options: dict | None (fsspec storage options) + # - saving_options: dict | None + # (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) + self._backend_options = {} if backend_options is None else backend_options + # extensions are not loaded at init self.extensions = dict() @@ -228,13 +269,18 @@ def __repr__(self) -> str: nchan = self.get_num_channels() nunits = self.get_num_units() txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" + if self.format != "memory": + if is_path_remote(str(self.folder)): + txt += f" (remote)" if self.is_sparse(): txt += " - sparse" if self.has_recording(): txt += " - has recording" if self.has_temporary_recording(): txt += " - has temporary recording" - ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys()) + ext_txt = f"Loaded {len(self.extensions)} extensions" + if len(self.extensions) > 0: + ext_txt += f": {', '.join(self.extensions.keys())}" txt += "\n" + ext_txt return txt @@ -253,7 +299,9 @@ def create( folder=None, sparsity=None, return_scaled=True, + backend_options=None, ): + assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" # some checks if sorting.sampling_frequency != recording.sampling_frequency: if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): @@ -277,22 +325,35 @@ def create( if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_scaled, rec_attributes=None) elif format == "binary_folder": - cls.create_binary_folder(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + sorting_analyzer = cls.create_binary_folder( + folder, + sorting, + recording, + sparsity, + 
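A sketch of loading a remote zarr analyzer with the new fallback: when backend_options is None and the path is remote, the loader retries in anonymous mode (storage_options=dict(anon=True)). Bucket paths and credentials below are placeholders:

    from spikeinterface.core import load_sorting_analyzer

    # open bucket: the anonymous fallback kicks in automatically
    analyzer = load_sorting_analyzer("s3://some-open-bucket/analyzer.zarr")

    # private bucket: pass explicit fsspec credentials
    analyzer = load_sorting_analyzer(
        "s3://some-private-bucket/analyzer.zarr",
        backend_options=dict(storage_options=dict(key="ACCESS_KEY", secret="SECRET_KEY")),
    )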
return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - folder = clean_zarr_folder_name(folder) - cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_zarr(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + sorting_analyzer = cls.create_zarr( + folder, + sorting, + recording, + sparsity, + return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) else: raise ValueError("SortingAnalyzer.create: wrong format") return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): + def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. @@ -306,18 +367,15 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", stora format = "binary_folder" if format == "binary_folder": - sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_binary_folder( + folder, recording=recording, backend_options=backend_options + ) elif format == "zarr": sorting_analyzer = SortingAnalyzer.load_from_zarr( - folder, recording=recording, storage_options=storage_options + folder, recording=recording, backend_options=backend_options ) - if is_path_remote(str(folder)): - sorting_analyzer.folder = folder - # in this case we only load extensions when needed - else: - sorting_analyzer.folder = Path(folder) - + if not is_path_remote(str(folder)): if load_extensions: sorting_analyzer.load_all_saved_extension() @@ -349,11 +407,9 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as - assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" - folder = Path(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") @@ -369,26 +425,34 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale json.dump(check_json(info), f, indent=4) # save a copy of the sorting - # NumpyFolderSorting.write_sorting(sorting, folder / "sorting") sorting.save(folder=folder / "sorting") - # save recording and sorting provenance - if recording.check_serializability("json"): - recording.dump(folder / "recording.json", relative_to=folder) - elif recording.check_serializability("pickle"): - recording.dump(folder / "recording.pickle", relative_to=folder) + if recording is not None: + # save recording and sorting provenance + if recording.check_serializability("json"): + recording.dump(folder / "recording.json", relative_to=folder) + elif recording.check_serializability("pickle"): + recording.dump(folder / "recording.pickle", relative_to=folder) + else: + warnings.warn("The Recording is not serializable! 
The recording link will be lost for future load") + else: + assert rec_attributes is not None, "recording or rec_attributes must be provided" + warnings.warn("Recording not provided, instantiating SortingAnalyzer in recordingless mode.") if sorting.check_serializability("json"): sorting.dump(folder / "sorting_provenance.json", relative_to=folder) elif sorting.check_serializability("pickle"): sorting.dump(folder / "sorting_provenance.pickle", relative_to=folder) + else: + warnings.warn( + "The sorting provenance is not serializable! The sorting provenance link will be lost for future load" + ) # dump recording attributes probegroup = None rec_attributes_file = folder / "recording_info" / "recording_attributes.json" rec_attributes_file.parent.mkdir() if rec_attributes is None: - assert recording is not None rec_attributes = get_rec_attributes(recording) rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8") probegroup = recording.get_probegroup() @@ -411,8 +475,10 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) + return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_binary_folder(cls, folder, recording=None): + def load_from_binary_folder(cls, folder, recording=None, backend_options=None): folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -428,7 +494,7 @@ def load_from_binary_folder(cls, folder, recording=None): filename = folder / f"recording.{type}" if filename.exists(): try: - recording = load_extractor(filename, base_folder=folder) + recording = load(filename, base_folder=folder) break except: recording = None @@ -483,34 +549,50 @@ def load_from_binary_folder(cls, folder, recording=None): format="binary_folder", sparsity=sparsity, return_scaled=return_scaled, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer def _get_zarr_root(self, mode="r+"): import zarr - if is_path_remote(str(self.folder)): - mode = "r" + assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'" + + storage_options = self._backend_options.get("storage_options", {}) # we open_consolidated only if we are in read mode if mode in ("r+", "a"): - zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + try: + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=storage_options) + except Exception as e: + # this could happen in remote mode, and it's a way to check if the folder is still there + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) else: - zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs + from .zarrextractors import add_sorting_to_zarr_group - folder = clean_zarr_folder_name(folder) + if is_path_remote(folder): + remote = True + else: + remote = False + if not remote: + folder = clean_zarr_folder_name(folder) + if folder.is_dir(): + raise ValueError(f"Folder already exists
{folder}") - if folder.is_dir(): - raise ValueError(f"Folder already exists {folder}") + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + saving_options = backend_options.get("saving_options", {}) - zarr_root = zarr.open(folder, mode="w") + zarr_root = zarr.open(folder, mode="w", storage_options=storage_options) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -519,37 +601,39 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at zarr_root.attrs["settings"] = check_json(settings) # the recording - rec_dict = recording.to_dict(relative_to=folder, recursive=True) - - if recording.check_serializability("json"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) - zarr_rec = np.array([check_json(rec_dict)], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) - elif recording.check_serializability("pickle"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) - zarr_rec = np.array([rec_dict], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + relative_to = folder if not remote else None + if recording is not None: + rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) + if recording.check_serializability("json"): + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) + zarr_rec = np.array([check_json(rec_dict)], dtype=object) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) + elif recording.check_serializability("pickle"): + # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) + zarr_rec = np.array([rec_dict], dtype=object) + zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + else: + warnings.warn("The Recording is not serializable! The recording link will be lost for future load") else: - warnings.warn( - "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for future load" - ) + assert rec_attributes is not None, "recording or rec_attributes must be provided" + warnings.warn("Recording not provided, instntiating SortingAnalyzer in recordingless mode.") # sorting provenance - sort_dict = sorting.to_dict(relative_to=folder, recursive=True) + sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True) if sorting.check_serializability("json"): zarr_sort = np.array([check_json(sort_dict)], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) elif sorting.check_serializability("pickle"): zarr_sort = np.array([sort_dict], dtype=object) zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) - - # else: - # warnings.warn("SortingAnalyzer with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load") + else: + warnings.warn( + "The sorting provenance is not serializable! 
The sorting provenance link will be lost for future load" + ) recording_info = zarr_root.create_group("recording_info") if rec_attributes is None: - assert recording is not None rec_attributes = get_rec_attributes(recording) probegroup = recording.get_probegroup() else: @@ -562,24 +646,23 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask) - - # write sorting copy - from .zarrextractors import add_sorting_to_zarr_group + zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) - # Alessio : we need to find a way to propagate compressor for all steps. - # kwargs = dict(compressor=...) - zarr_kwargs = dict() - add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) + add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") zarr.consolidate_metadata(zarr_root.store) + return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_zarr(cls, folder, recording=None, storage_options=None): + def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) si_info = zarr_root.attrs["spikeinterface_info"] @@ -605,11 +688,13 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): # load recording if possible if recording is None: - rec_dict = zarr_root["recording"][0] - try: - recording = load_extractor(rec_dict, base_folder=folder) - except: - recording = None + rec_field = zarr_root.get("recording") + if rec_field is not None: + rec_dict = rec_field[0] + try: + recording = load(rec_dict, base_folder=folder) + except: + recording = None else: # TODO maybe maybe not??? : do we need to check attributes match internal rec_attributes # Note this will make the loading too slow @@ -640,8 +725,9 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, - storage_options=storage_options, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer @@ -683,6 +769,7 @@ def _save_or_select_or_merge( sparsity_overlap=0.75, verbose=False, new_unit_ids=None, + backend_options=None, **job_kwargs, ) -> "SortingAnalyzer": """ @@ -712,8 +799,13 @@ def _save_or_select_or_merge( The new unit ids for merged units. Required if `merge_unit_groups` is not None. verbose : bool, default: False If True, output is verbose. - job_kwargs : dict - Keyword arguments for parallelization. + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the following keys: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) + job_kwargs : keyword arguments + Keyword arguments for the job parallelization. Returns ------- @@ -787,6 +879,8 @@ def _save_or_select_or_merge( # TODO: sam/pierre would create a curation field / curation.json with the applied merges. # What do you think?
+ backend_options = {} if backend_options is None else backend_options + if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( @@ -797,20 +891,28 @@ def _save_or_select_or_merge( # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" folder = Path(folder) - SortingAnalyzer.create_binary_folder( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + new_sorting_analyzer = SortingAnalyzer.create_binary_folder( + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) - new_sorting_analyzer.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) - SortingAnalyzer.create_zarr( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + new_sorting_analyzer = SortingAnalyzer.create_zarr( + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) - new_sorting_analyzer.folder = folder else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") @@ -848,7 +950,7 @@ def _save_or_select_or_merge( return new_sorting_analyzer - def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": + def save_as(self, format="memory", folder=None, backend_options=None) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. Uselful for memory to zarr or memory to binary. @@ -863,10 +965,15 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": The output folder if `format` is "zarr" or "binary_folder" format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the following keys: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g.
compression/filters for zarr) """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder) + return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1029,7 +1136,15 @@ def copy(self): def is_read_only(self) -> bool: if self.format == "memory": return False - return not os.access(self.folder, os.W_OK) + elif self.format == "binary_folder": + return not os.access(self.folder, os.W_OK) + else: + if not is_path_remote(str(self.folder)): + return not os.access(self.folder, os.W_OK) + else: + # in this case we don't know if the file is read only so an error + # will be raised if we try to save/append + return False ## map attribute and property zone @@ -1077,7 +1192,7 @@ def get_sorting_provenance(self): sorting_provenance = None if filename.exists(): try: - sorting_provenance = load_extractor(filename, base_folder=self.folder) + sorting_provenance = load(filename, base_folder=self.folder) break except: pass @@ -1087,7 +1202,7 @@ def get_sorting_provenance(self): zarr_root = self._get_zarr_root(mode="r") if "sorting_provenance" in zarr_root.keys(): sort_dict = zarr_root["sorting_provenance"][0] - sorting_provenance = load_extractor(sort_dict, base_folder=self.folder) + sorting_provenance = load(sort_dict, base_folder=self.folder) else: sorting_provenance = None @@ -1965,7 +2080,8 @@ def load_data(self): continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": - ext_data = json.load(ext_data_file.open("r")) + with ext_data_file.open("r") as f: + ext_data = json.load(f) elif ext_data_file.suffix == ".npy": # The lazy loading of an extension is complicated because if we compute again # and have a link to the old buffer on windows then it fails @@ -1976,8 +2092,16 @@ def load_data(self): import pandas as pd ext_data = pd.read_csv(ext_data_file, index_col=0) + # we need to cast the index to the unit id dtype (int or str) + unit_ids = self.sorting_analyzer.unit_ids + if ext_data.shape[0] == unit_ids.size: + # we force dtype to be the same as unit_ids + if ext_data.index.dtype != unit_ids.dtype: + ext_data.index = ext_data.index.astype(unit_ids.dtype) + elif ext_data_file.suffix == ".pkl": - ext_data = pickle.load(ext_data_file.open("rb")) + with ext_data_file.open("rb") as f: + ext_data = pickle.load(f) else: continue self.data[ext_data_name] = ext_data @@ -2015,7 +2139,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None): new_extension.data = self.data else: new_extension.data = self._select_extension_data(unit_ids) - new_extension.run_info = self.run_info.copy() + new_extension.run_info = copy(self.run_info) new_extension.save() return new_extension @@ -2033,7 +2157,7 @@ def merge( new_extension.data = self._merge_extension_data( merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs ) - new_extension.run_info = self.run_info.copy() + new_extension.run_info = copy(self.run_info) new_extension.save() return new_extension @@ -2051,24 +2175,24 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def save(self, **kwargs): + def save(self): self._save_params() self._save_importing_provenance() 
self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def _save_data(self, **kwargs): + def _save_data(self): if self.format == "memory": return @@ -2107,14 +2231,14 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - import zarr import numcodecs + saving_options = self.sorting_analyzer._backend_options.get("saving_options", {}) extension_group = self._get_zarr_extension_group(mode="r+") - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() + # if compression is not externally given, we use the default + if "compressor" not in saving_options: + saving_options["compressor"] = get_default_zarr_compressor() for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: @@ -2124,13 +2248,19 @@ def _save_data(self, **kwargs): name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index - df_group.create_dataset(name="index", data=ext_data.index.to_numpy()) + indices = ext_data.index.to_numpy() + if indices.dtype.kind == "O": + indices = indices.astype(str) + df_group.create_dataset(name="index", data=indices) for col in ext_data.columns: - df_group.create_dataset(name=col, data=ext_data[col].to_numpy()) + col_data = ext_data[col].to_numpy() + if col_data.dtype.kind == "O": + col_data = col_data.astype(str) + df_group.create_dataset(name=col, data=col_data) df_group.attrs["dataframe"] = True else: # any object @@ -2187,7 +2317,7 @@ def delete(self): def reset(self): """ - Reset the waveform extension. + Reset the extension. Delete the sub folder and create a new empty one. 
""" self._reset_extension_folder() @@ -2202,7 +2332,8 @@ def set_params(self, save=True, **params): """ # this ensure data is also deleted and corresponds to params # this also ensure the group is created - self._reset_extension_folder() + if save: + self._reset_extension_folder() params = self._set_params(**params) self.params = params @@ -2251,15 +2382,16 @@ def _save_importing_provenance(self): extension_group.attrs["info"] = info def _save_run_info(self): - run_info = self.run_info.copy() - - if self.format == "binary_folder": - extension_folder = self._get_binary_extension_folder() - run_info_file = extension_folder / "run_info.json" - run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") - elif self.format == "zarr": - extension_group = self._get_zarr_extension_group(mode="r+") - extension_group.attrs["run_info"] = run_info + if self.run_info is not None: + run_info = self.run_info.copy() + + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["run_info"] = run_info def get_pipeline_nodes(self): assert ( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index b64f0610ea..3e3fcc7384 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -205,6 +205,7 @@ def to_sparse(self, sparsity): unit_ids=self.unit_ids, probe=self.probe, check_for_consistent_sparsity=self.check_for_consistent_sparsity, + is_scaled=self.is_scaled, ) def get_one_template_dense(self, unit_index): diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 934b18ed49..3c8663df70 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,7 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] + if "average" in ext.data: + templates_array = ext.data.get("average") + elif "median" in ext.data: + templates_array = ext.data.get("median") + else: + raise ValueError("Average or median templates have not been computed.") else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 626899ab6e..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -2,6 +2,8 @@ import shutil +from pathlib import Path + from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import Templates @@ -250,16 +252,17 @@ def test_compute_several(create_cache_folder): if __name__ == "__main__": - - test_ComputeWaveforms(format="memory", sparse=True) - test_ComputeWaveforms(format="memory", sparse=False) - test_ComputeWaveforms(format="binary_folder", sparse=True) - test_ComputeWaveforms(format="binary_folder", sparse=False) - test_ComputeWaveforms(format="zarr", sparse=True) - test_ComputeWaveforms(format="zarr", sparse=False) - 
test_ComputeRandomSpikes(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeNoiseLevels(format="memory", sparse=False) - - test_get_children_dependencies() - test_delete_on_recompute() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core" + # test_ComputeWaveforms(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="memory", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) + + # test_get_children_dependencies() + # test_delete_on_recompute(cache_folder) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index df614978ba..7d7ce52f27 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -12,7 +12,7 @@ from probeinterface import Probe, ProbeGroup, generate_linear_probe -from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load_extractor, get_default_zarr_compressor +from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load, get_default_zarr_compressor from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_recordings_equal @@ -84,38 +84,38 @@ def test_BaseRecording(create_cache_folder): # dump/load dict d = rec.to_dict(include_annotations=True, include_properties=True) rec2 = BaseExtractor.from_dict(d) - rec3 = load_extractor(d) + rec3 = load(d) check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True, check_properties=True) check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True, check_properties=True) # dump/load json rec.dump_to_json(cache_folder / "test_BaseRecording.json") rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording.json") - rec3 = load_extractor(cache_folder / "test_BaseRecording.json") + rec3 = load(cache_folder / "test_BaseRecording.json") check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True, check_properties=False) check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True, check_properties=False) # dump/load pickle rec.dump_to_pickle(cache_folder / "test_BaseRecording.pkl") rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording.pkl") - rec3 = load_extractor(cache_folder / "test_BaseRecording.pkl") + rec3 = load(cache_folder / "test_BaseRecording.pkl") check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True, check_properties=True) check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True, check_properties=True) # dump/load dict - relative d = rec.to_dict(relative_to=cache_folder, recursive=True) rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder) - rec3 = 
load_extractor(d, base_folder=cache_folder) + rec3 = load(d, base_folder=cache_folder) # dump/load json - relative to rec.dump_to_json(cache_folder / "test_BaseRecording_rel.json", relative_to=cache_folder) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) - rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) + rec3 = load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) # dump/load relative=True rec.dump_to_json(cache_folder / "test_BaseRecording_rel_true.json", relative_to=True) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) - rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + rec3 = load(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) with open(cache_folder / "test_BaseRecording_rel_true.json") as json_file: @@ -127,12 +127,12 @@ def test_BaseRecording(create_cache_folder): # dump/load pkl - relative to rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel.pkl", relative_to=cache_folder) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) - rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) + rec3 = load(cache_folder / "test_BaseRecording_rel.pkl", base_folder=cache_folder) # dump/load relative=True rec.dump_to_pickle(cache_folder / "test_BaseRecording_rel_true.pkl", relative_to=True) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) - rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) + rec3 = load(cache_folder / "test_BaseRecording_rel_true.pkl", base_folder=True) check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) with open(cache_folder / "test_BaseRecording_rel_true.pkl", "rb") as pkl_file: @@ -195,7 +195,7 @@ def test_BaseRecording(create_cache_folder): # test save with probe folder = cache_folder / "simple_recording3" rec2 = rec_p.save(folder=folder, chunk_size=10, n_jobs=2) - rec2 = load_extractor(folder) + rec2 = load(folder) probe2 = rec2.get_probe() assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) positions2 = rec_p.get_channel_locations() @@ -286,7 +286,7 @@ def test_BaseRecording(create_cache_folder): folder = cache_folder / "recording_with_times" rec2 = rec.save(folder=folder) assert np.allclose(times1, rec2.get_times(1)) - rec3 = load_extractor(folder) + rec3 = load(folder) assert np.allclose(times1, rec3.get_times(1)) # reset times @@ -323,7 +323,7 @@ def test_BaseRecording(create_cache_folder): # test save to zarr compressor = get_default_zarr_compressor() rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressor=compressor) - rec_zarr_loaded = load_extractor(cache_folder / "recording.zarr") + rec_zarr_loaded = load(cache_folder / "recording.zarr") # annotations is False because Zarr adds compression ratios check_recordings_equal(rec2, rec_zarr, return_scaled=False, check_annotations=False, check_properties=True) check_recordings_equal( @@ -336,7 +336,7 @@ def test_BaseRecording(create_cache_folder): rec_zarr2 = rec2.save( format="zarr", folder=cache_folder / 
"recording_channel_chunk", compressor=compressor, channel_chunk_size=2 ) - rec_zarr2_loaded = load_extractor(cache_folder / "recording_channel_chunk.zarr") + rec_zarr2_loaded = load(cache_folder / "recording_channel_chunk.zarr") # annotations is False because Zarr adds compression ratios check_recordings_equal(rec2, rec_zarr2, return_scaled=False, check_annotations=False, check_properties=True) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 64f7f76819..3d6c19c974 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -10,7 +10,7 @@ from probeinterface import Probe from spikeinterface.core import generate_snippets -from spikeinterface.core import NumpySnippets, load_extractor +from spikeinterface.core import NumpySnippets, load from spikeinterface.core.npysnippetsextractor import NpySnippetsExtractor from spikeinterface.core.base import BaseExtractor @@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder): assert snippets.get_num_segments() == len(duration) assert snippets.get_num_channels() == num_channels - assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2]) - assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None)) + assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2]) + assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None)) # annotations / properties snippets.annotate(gre="ta") @@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder): ) # missing property - snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1]) + snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"]) values = snippets.get_property("string_property") assert values[2] == "" @@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder): snippets.set_property, key="string_property_nan", values=["hola", "chabon"], - ids=[0, 1], + ids=["0", "1"], missing_value=np.nan, ) # int properties without missing values raise an error assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2]) - snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200) + snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200) values = snippets.get_property("int_property") assert values.dtype.kind == "i" @@ -90,27 +90,27 @@ def test_BaseSnippets(create_cache_folder): # dump/load dict d = snippets.to_dict() snippets2 = BaseExtractor.from_dict(d) - snippets3 = load_extractor(d) + snippets3 = load(d) # dump/load json snippets.dump_to_json(cache_folder / "test_BaseSnippets.json") snippets2 = BaseExtractor.load(cache_folder / "test_BaseSnippets.json") - snippets3 = load_extractor(cache_folder / "test_BaseSnippets.json") + snippets3 = load(cache_folder / "test_BaseSnippets.json") # dump/load pickle snippets.dump_to_pickle(cache_folder / "test_BaseSnippets.pkl") snippets2 = BaseExtractor.load(cache_folder / "test_BaseSnippets.pkl") - snippets3 = load_extractor(cache_folder / "test_BaseSnippets.pkl") + snippets3 = load(cache_folder / "test_BaseSnippets.pkl") # dump/load dict - relative d = snippets.to_dict(relative_to=cache_folder, recursive=True) snippets2 = BaseExtractor.from_dict(d, base_folder=cache_folder) - snippets3 = load_extractor(d, base_folder=cache_folder) + snippets3 = load(d, base_folder=cache_folder) # dump/load json snippets.dump_to_json(cache_folder / 
"test_BaseSnippets_rel.json", relative_to=cache_folder) snippets2 = BaseExtractor.load(cache_folder / "test_BaseSnippets_rel.json", base_folder=cache_folder) - snippets3 = load_extractor(cache_folder / "test_BaseSnippets_rel.json", base_folder=cache_folder) + snippets3 = load(cache_folder / "test_BaseSnippets_rel.json", base_folder=cache_folder) # cache to npy folder = cache_folder / "simple_snippets" @@ -156,7 +156,7 @@ def test_BaseSnippets(create_cache_folder): # test save with probe folder = cache_folder / "simple_snippets3" snippets2 = snippets_p.save(folder=folder) - snippets2 = load_extractor(folder) + snippets2 = load(folder) probe2 = snippets2.get_probe() assert np.array_equal(probe2.contact_positions, [[0, 30.0], [0.0, 0.0]]) positions2 = snippets_p.get_channel_locations() diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 42fdf52eb1..557617ae12 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -19,7 +19,7 @@ NumpyFolderSorting, create_sorting_npz, generate_sorting, - load_extractor, + load, ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal @@ -51,21 +51,21 @@ def test_BaseSorting(create_cache_folder): # dump/load dict d = sorting.to_dict(include_annotations=True, include_properties=True) sorting2 = BaseExtractor.from_dict(d) - sorting3 = load_extractor(d) + sorting3 = load(d) check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) check_sortings_equal(sorting, sorting3, check_annotations=True, check_properties=True) # dump/load json sorting.dump_to_json(cache_folder / "test_BaseSorting.json") sorting2 = BaseExtractor.load(cache_folder / "test_BaseSorting.json") - sorting3 = load_extractor(cache_folder / "test_BaseSorting.json") + sorting3 = load(cache_folder / "test_BaseSorting.json") check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=False) check_sortings_equal(sorting, sorting3, check_annotations=True, check_properties=False) # dump/load pickle sorting.dump_to_pickle(cache_folder / "test_BaseSorting.pkl") sorting2 = BaseExtractor.load(cache_folder / "test_BaseSorting.pkl") - sorting3 = load_extractor(cache_folder / "test_BaseSorting.pkl") + sorting3 = load(cache_folder / "test_BaseSorting.pkl") check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) check_sortings_equal(sorting, sorting3, check_annotations=True, check_properties=True) @@ -122,7 +122,7 @@ def test_BaseSorting(create_cache_folder): sorting4 = sorting.to_numpy_sorting() sorting5 = sorting.to_multiprocessing(n_jobs=2) # create a clone with the same share mem buffer - sorting6 = load_extractor(sorting5.to_dict()) + sorting6 = load(sorting5.to_dict()) assert isinstance(sorting6, SharedMemorySorting) del sorting6 del sorting5 @@ -130,7 +130,7 @@ def test_BaseSorting(create_cache_folder): # test save to zarr # compressor = get_default_zarr_compressor() sorting_zarr = sorting.save(format="zarr", folder=cache_folder / "sorting") - sorting_zarr_loaded = load_extractor(cache_folder / "sorting.zarr") + sorting_zarr_loaded = load(cache_folder / "sorting.zarr") # annotations is False because Zarr adds compression ratios check_sortings_equal(sorting, sorting_zarr, check_annotations=False, check_properties=True) check_sortings_equal(sorting_zarr, sorting_zarr_loaded, check_annotations=False, 
check_properties=True) diff --git a/src/spikeinterface/core/tests/test_binaryfolder.py b/src/spikeinterface/core/tests/test_binaryfolder.py index 1e64afe4e4..049e613541 100644 --- a/src/spikeinterface/core/tests/test_binaryfolder.py +++ b/src/spikeinterface/core/tests/test_binaryfolder.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core import BinaryFolderRecording, read_binary_folder, load_extractor +from spikeinterface.core import load from spikeinterface.core import generate_recording @@ -20,7 +20,7 @@ def test_BinaryFolderRecording(create_cache_folder): saved_rec = rec.save(folder=folder) print(saved_rec) - loaded_rec = load_extractor(folder) + loaded_rec = load(folder) print(loaded_rec) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 118b6092a9..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -38,10 +38,12 @@ def test_channelsaggregationrecording(): assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) assert np.allclose( - traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg) + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), ) assert np.allclose( - traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg) + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), ) # all traces traces1 = recording1.get_traces(segment_index=seg) diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index cb7debf3e0..3f067c7cf8 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -3,7 +3,7 @@ import numpy as np -from spikeinterface.core import load_extractor +from spikeinterface.core import load from probeinterface import generate_multi_columns_probe from spikeinterface.core.generate import ( @@ -363,7 +363,7 @@ def test_noise_generator_consistency_after_dump(strategy, seed): ) traces0 = rec0.get_traces() - rec1 = load_extractor(rec0.to_dict()) + rec1 = load(rec0.to_dict()) traces1 = rec1.get_traces() assert np.allclose(traces0, traces1) @@ -545,7 +545,7 @@ def test_inject_templates(): assert rec.get_traces(start_frame=rec_noise.get_num_frames(0) - 200, segment_index=0).shape == (200, 4) # Check dumpability - saved_loaded = load_extractor(rec.to_dict()) + saved_loaded = load(rec.to_dict()) check_recordings_equal(rec, saved_loaded, return_scaled=False) diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 9677378fc5..cc8ff10075 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,30 +36,38 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + job_kwargs = dict( + pool_engine="thread", + n_jobs=4, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, + ) global_job_kwargs = get_global_job_kwargs() - # test warning when not setting n_jobs and calling fix_job_kwargs - with 
pytest.warns(UserWarning): - job_kwargs_split = fix_job_kwargs({}) - assert global_job_kwargs == dict( - n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + pool_engine="thread", + n_jobs=1, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs - # after setting global job kwargs, fix_job_kwargs should not raise a warning - with warnings.catch_warnings(): - warnings.simplefilter("error") - job_kwargs_split = fix_job_kwargs({}) - # test updating only one field partial_job_kwargs = dict(n_jobs=2) set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + pool_engine="thread", + n_jobs=2, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) @@ -80,6 +88,6 @@ def test_global_job_kwargs(): if __name__ == "__main__": - test_global_dataset_folder() - test_global_tmp_folder() + # test_global_dataset_folder() + # test_global_tmp_folder() test_global_job_kwargs() diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 2f3aff0023..b0c169890c 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -1,7 +1,9 @@ import pytest import os -from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs +import time + +from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs, get_best_job_kwargs from spikeinterface.core.job_tools import ( divide_segment_into_chunks, @@ -77,28 +79,25 @@ def test_ensure_chunk_size(): assert end_frame == recording.get_num_frames(segment_index=segment_index) -def func(segment_index, start_frame, end_frame, worker_ctx): +def func(segment_index, start_frame, end_frame, worker_dict): import os - import time - #  print('func', segment_index, start_frame, end_frame, worker_ctx, os.getpid()) + #  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid()) time.sleep(0.010) # time.sleep(1.0) return os.getpid() def init_func(arg1, arg2, arg3): - worker_ctx = {} - worker_ctx["arg1"] = arg1 - worker_ctx["arg2"] = arg2 - worker_ctx["arg3"] = arg3 - return worker_ctx + worker_dict = {} + worker_dict["arg1"] = arg1 + worker_dict["arg2"] = arg2 + worker_dict["arg3"] = arg3 + return worker_dict def test_ChunkRecordingExecutor(): recording = generate_recording(num_channels=2) - # make serializable - recording = recording.save() init_args = "a", 120, "yep" @@ -139,7 +138,7 @@ def __call__(self, res): gathering_func2 = GatherClass() - # chunk + parallel + gather_func + # process + gather_func processor = ChunkRecordingExecutor( recording, func, @@ -148,6 +147,7 @@ def __call__(self, res): verbose=True, progress_bar=True, gather_func=gathering_func2, + pool_engine="process", n_jobs=2, chunk_duration="200ms", job_name="job_name", @@ -157,7 +157,7 @@ def __call__(self, res): assert gathering_func2.pos == num_chunks - # chunk + parallel + spawn + # process spawn processor = ChunkRecordingExecutor( recording, func, @@ -165,6 +165,7 @@ def __call__(self, res): init_args, verbose=True, progress_bar=True, + pool_engine="process", mp_context="spawn", 
n_jobs=2, chunk_duration="200ms", @@ -172,6 +173,21 @@ def __call__(self, res): ) processor.run() + # thread + processor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + verbose=True, + progress_bar=True, + pool_engine="thread", + n_jobs=2, + chunk_duration="200ms", + job_name="job_name", + ) + processor.run() + def test_fix_job_kwargs(): # test negative n_jobs @@ -220,10 +236,94 @@ def test_split_job_kwargs(): assert "other_param" not in job_kwargs and "n_jobs" in job_kwargs and "progress_bar" in job_kwargs +def func2(segment_index, start_frame, end_frame, worker_dict): + time.sleep(0.010) + # print(os.getpid(), worker_dict["worker_index"]) + return worker_dict["worker_index"] + + +def init_func2(): + # this leave time for other thread/process to start + time.sleep(0.010) + worker_dict = {} + return worker_dict + + +def test_worker_index(): + recording = generate_recording(num_channels=2) + init_args = tuple() + + for i in range(2): + # making this 2 times ensure to test that global variables are correctly reset + for pool_engine in ("process", "thread"): + processor = ChunkRecordingExecutor( + recording, + func2, + init_func2, + init_args, + progress_bar=False, + gather_func=None, + pool_engine=pool_engine, + n_jobs=2, + handle_returns=True, + chunk_duration="200ms", + need_worker_index=True, + ) + res = processor.run() + # we should have a mix of 0 and 1 + assert 0 in res + assert 1 in res + + +def test_get_best_job_kwargs(): + job_kwargs = get_best_job_kwargs() + print(job_kwargs) + + +# def quick_becnhmark(): +# # keep this commented do not remove + +# from spikeinterface.generation import generate_drifting_recording +# from spikeinterface.sortingcomponents.peak_detection import detect_peaks +# from spikeinterface import get_noise_levels +# import time + +# all_job_kwargs = [ +# dict(pool_engine="process", n_jobs=2, mp_context="spawn", max_threads_per_worker=2), +# dict(pool_engine="process", n_jobs=4, mp_context="spawn", max_threads_per_worker=1), +# dict(pool_engine="thread", n_jobs=4, mp_context=None, max_threads_per_worker=1), +# dict(pool_engine="thread", n_jobs=2, mp_context=None, max_threads_per_worker=2), +# dict(n_jobs=1), +# ] + + +# rec, _, sorting = generate_drifting_recording( +# num_units=50, +# duration=120.0, +# sampling_frequency=30000.0, +# probe_name="Neuropixel-128", + +# ) +# # print(rec) + +# noise_levels = get_noise_levels(rec, return_scaled=False) +# for job_kwargs in all_job_kwargs: +# print() +# print(job_kwargs) +# t0 = time.perf_counter() +# peaks = detect_peaks(rec, method="locally_exclusive", noise_levels=noise_levels, **job_kwargs) +# t1 = time.perf_counter() +# print("time included the spawn:", t1-t0) + + if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() # test_ChunkRecordingExecutor() - test_fix_job_kwargs() + # test_fix_job_kwargs() # test_split_job_kwargs() + # test_worker_index() + test_get_best_job_kwargs() + + # quick_becnhmark() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8d788acbad..028eaecf12 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording - +from spikeinterface.core.job_tools import divide_recording_into_chunks # from 
spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -83,8 +83,12 @@ def test_run_node_pipeline(cache_folder_creation): extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) peak_retriever = PeakRetriever(recording, peaks) + # this test when no spikes in last chunks + peak_retriever_few = PeakRetriever(recording, peaks[: peaks.size // 2]) + # channel index is from template spike_retriever_T = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds @@ -100,7 +104,7 @@ def test_run_node_pipeline(cache_folder_creation): ) # test with 3 differents first nodes - for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): + for loop, peak_source in enumerate((peak_retriever, peak_retriever_few, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ peak_source, @@ -139,10 +143,12 @@ def test_run_node_pipeline(cache_folder_creation): num_peaks = peaks.shape[0] num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels - assert waveforms_rms.shape[0] == num_peaks + if peak_source != peak_retriever_few: + assert waveforms_rms.shape[0] == num_peaks assert waveforms_rms.shape[1] == num_channels # gather npy mode @@ -185,5 +191,47 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) +def test_skip_after_n_peaks_and_recording_slices(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) + + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + + spikes = sorting.to_spike_vector() + + # create peaks from spikes + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + + peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) + # print(peaks.size) + + node0 = PeakRetriever(recording, peaks) + node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) + nodes = [node0, node1] + + # skip + skip_after_n_peaks = 30 + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks + ) + assert some_amplitudes.size >= skip_after_n_peaks + assert some_amplitudes.size < spikes.size + + # slices : 1 every 4 + recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = recording_slices[::4] + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices + ) + tolerance = 1.2 + assert some_amplitudes.size < (spikes.size // 4) * tolerance + + +# the following is for testing locally with python or ipython. It is not used in ci or with pytest. 
if __name__ == "__main__": - test_run_node_pipeline() + # folder = Path("./cache_folder/core") + # test_run_node_pipeline(folder) + + test_skip_after_n_peaks_and_recording_slices() diff --git a/src/spikeinterface/core/tests/test_npyfoldersnippets.py b/src/spikeinterface/core/tests/test_npyfoldersnippets.py index c0d7f303bf..ebf56e3985 100644 --- a/src/spikeinterface/core/tests/test_npyfoldersnippets.py +++ b/src/spikeinterface/core/tests/test_npyfoldersnippets.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface.core import load_extractor +from spikeinterface.core import load from spikeinterface.core import generate_snippets @@ -25,7 +25,7 @@ def test_NpyFolderSnippets(cache_folder_creation): saved_snippets = snippets.save(folder=folder) print(snippets) - loaded_snippets = load_extractor(folder) + loaded_snippets = load(folder) print(loaded_snippets) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index fecafb8989..21bc1b7879 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -9,7 +9,7 @@ SharedMemorySorting, NumpyEvent, create_sorting_npz, - load_extractor, + load, NpzSortingExtractor, generate_recording, ) @@ -41,7 +41,7 @@ def test_SharedMemoryRecording(): rec = SharedMemoryRecording.from_recording(rec0, **job_kwargs) d = rec.to_dict() - rec_clone = load_extractor(d) + rec_clone = load(d) traces = rec_clone.get_traces(start_frame=0, end_frame=30000, segment_index=0) assert rec.shms[0].name == rec_clone.shms[0].name @@ -87,7 +87,7 @@ def test_NumpySorting(setup_NumpyRecording): # print(sorting) # construct back from kwargs keep the same array - sorting2 = load_extractor(sorting.to_dict()) + sorting2 = load(sorting.to_dict()) assert np.shares_memory(sorting2._cached_spike_vector, sorting._cached_spike_vector) @@ -109,7 +109,7 @@ def test_SharedMemorySorting(): # print(sorting.to_spike_vector()) d = sorting.to_dict() - sorting_reload = load_extractor(d) + sorting_reload = load(d) # print(sorting_reload) # print(sorting_reload.to_spike_vector()) diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,6 +11,7 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, + get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -167,6 +168,17 @@ def test_write_memory_recording(): shm.unlink() +def test_get_random_recording_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -182,16 +194,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, 
sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +214,10 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -216,10 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, chunk_size=1_000) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) + print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, chunk_size=1_000) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -333,14 +359,16 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile - - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() - - test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # import tempfile + + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() + + test_get_random_recording_slices() + # test_get_random_data_chunks() + # test_get_closest_channels() + # test_get_noise_levels() + # test_get_noise_levels_output() + # test_order_channels_by_depth() diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 34bb3a221d..7d26773ac3 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ 
b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -162,8 +162,8 @@ def test_generate_unit_ids_for_merge_group(): if __name__ == "__main__": # test_spike_vector_to_spike_trains() # test_spike_vector_to_indices() - # test_random_spikes_selection() + test_random_spikes_selection() - test_apply_merges_to_sorting() - test_get_ids_after_merging() + # test_apply_merges_to_sorting() + # test_get_ids_after_merging() # test_generate_unit_ids_for_merge_group() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 5c7e267cc6..15f089f784 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -10,6 +10,7 @@ load_sorting_analyzer, get_available_analyzer_extensions, get_default_analyzer_extension_params, + get_default_zarr_compressor, ) from spikeinterface.core.sortinganalyzer import ( register_result_extension, @@ -30,6 +31,14 @@ def get_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), seed=2205, ) + + # TODO: the tests or the sorting analyzer make assumptions about the ids being integers + # So keeping this the way it was + integer_channel_ids = [int(id) for id in recording.get_channel_ids()] + integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] + + recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting @@ -99,16 +108,25 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) + default_compressor = get_default_zarr_compressor() sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, overwrite=True ) sorting_analyzer.compute(["random_spikes", "templates"]) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + # check that compression is applied + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id + == default_compressor.codec_id + ) + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == default_compressor.codec_id + ) + # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) @@ -117,11 +135,45 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 assert 1 not in remove_units_sorting_analyer.unit_ids - folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) - sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, return_scaled=False + # test no compression + sorting_analyzer_no_compression = create_sorting_analyzer( + sorting, + recording, + format="zarr", + folder=folder, + sparse=False, + sparsity=None, + return_scaled=False, + overwrite=True, + backend_options={"saving_options": {"compressor": None}}, + ) + 
print(sorting_analyzer_no_compression._backend_options) + sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) + assert ( + sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor + is None + ) + assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + + # test a different compressor + from numcodecs import LZMA + + lzma_compressor = LZMA() + folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" + sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( + format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} + ) + assert ( + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor.codec_id + == LZMA.codec_id + ) + assert ( + sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == LZMA.codec_id ) @@ -326,7 +378,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): else: folder = None sorting_analyzer5 = sorting_analyzer.merge_units( - merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, mode="hard" + merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard" ) # test compute with extension-specific params diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..ffdb121316 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -122,7 +128,7 @@ def test_times_propagated_to_save_folder(self, request, fixture_name, mode, tmp_ if mode == "zarr": folder_name += ".zarr" - recording_load = si.load_extractor(tmp_path / folder_name) + recording_load = si.load(tmp_path / folder_name) self._check_times_match(recording_cache, all_times) self._check_times_match(recording_load, all_times) @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. 
+ """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. 
+ """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 1e72b0ab28..3ecb702aa2 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -10,25 +10,29 @@ def test_basic_functions(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2]) - assert np.array_equal(sorting2.unit_ids, [0, 2]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"]) + assert np.array_equal(sorting2.unit_ids, ["0", "2"]) assert sorting2.get_parent() == sorting - sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"]) + sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"]) assert np.array_equal(sorting3.unit_ids, ["a", "b"]) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting2.get_unit_spike_train(unit_id="0", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting3.get_unit_spike_train(unit_id="a", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting2.get_unit_spike_train(unit_id="2", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting3.get_unit_spike_train(unit_id="b", segment_index=0), ) @@ -36,13 +40,13 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_custom_cache_spike_vector(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) assert np.all(cached_spike_vector == computed_spike_vector) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 845eaf1310..a516e6d42b 100644 --- 
a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -173,21 +173,43 @@ def test_estimate_templates_with_accumulator(): job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - templates = estimate_templates_with_accumulator( - recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs - ) - print(templates.shape) - assert templates.shape[0] == sorting.unit_ids.size - assert templates.shape[1] == nbefore + nafter - assert templates.shape[2] == recording.get_num_channels() - - assert np.any(templates != 0) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # for unit_index, unit_id in enumerate(sorting.unit_ids): - # ax.plot(templates[unit_index, :, :].T.flatten()) - # plt.show() + # here we compare the result with the same mechanism with with several worker pool size + # this means that that acumulator are splitted and then agglomerated back + # this should lead to very small diff + # n_jobs=1 is done in loop + templates_by_worker = [] + + if platform.system() == "Linux": + engine_loop = ["thread", "process"] + else: + engine_loop = ["thread"] + + for pool_engine in engine_loop: + for n_jobs in (1, 2, 8): + job_kwargs = dict(pool_engine=pool_engine, n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + templates = estimate_templates_with_accumulator( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs + ) + assert templates.shape[0] == sorting.unit_ids.size + assert templates.shape[1] == nbefore + nafter + assert templates.shape[2] == recording.get_num_channels() + assert np.any(templates != 0) + + templates_by_worker.append(templates) + if len(templates_by_worker) > 1: + templates_loop = templates_by_worker[0] + np.testing.assert_almost_equal(templates, templates_loop, decimal=4) + + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, sharex=True) + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # ax = axs[0] + # ax.set_title(f"{pool_engine} {n_jobs}") + # ax.plot(templates[unit_index, :, :].T.flatten()) + # ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--") + # ax = axs[1] + # ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--") + # plt.show() def test_estimate_templates(): @@ -225,6 +247,6 @@ def test_estimate_templates(): if __name__ == "__main__": - test_waveform_tools() + # test_waveform_tools() test_estimate_templates_with_accumulator() - test_estimate_templates() + # test_estimate_templates() diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index 2fc1f42ec5..cc0c60721e 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -8,7 +8,7 @@ ZarrSortingExtractor, generate_recording, generate_sorting, - load_extractor, + load, ) from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor @@ -63,7 +63,7 @@ def test_ZarrSortingExtractor(tmp_path): folder = tmp_path / "zarr_sorting" ZarrSortingExtractor.write_sorting(np_sorting, folder) sorting = ZarrSortingExtractor(folder) - sorting = load_extractor(sorting.to_dict()) + sorting = load(sorting.to_dict()) # store the sorting in a sub group (for instance SortingResult) folder = tmp_path / "zarr_sorting_sub_group" @@ -72,7 +72,7 @@ def test_ZarrSortingExtractor(tmp_path): add_sorting_to_zarr_group(sorting, 
zarr_sorting_group) sorting = ZarrSortingExtractor(folder, zarr_group="sorting") # and reaload - sorting = load_extractor(sorting.to_dict()) + sorting = load(sorting.to_dict()) if __name__ == "__main__": diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 3affd7f0ec..76a1289711 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -296,17 +296,18 @@ def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker - worker_ctx = {} + worker_dict = {} if isinstance(recording, dict): - from spikeinterface.core import load_extractor + from spikeinterface.core import load - recording = load_extractor(recording) - worker_ctx["recording"] = recording + recording = load(recording) + + worker_dict["recording"] = recording if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["arrays_info"] = arrays_info + worker_dict["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory @@ -321,33 +322,33 @@ def _init_worker_distribute_buffers( waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["waveforms_by_units"] = waveforms_by_units + worker_dict["shms"] = shms + worker_dict["waveforms_by_units"] = waveforms_by_units - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes + worker_dict["unit_ids"] = unit_ids + worker_dict["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["inds_by_unit"] = inds_by_unit + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] - sparsity_mask = worker_ctx["sparsity_mask"] + recording = worker_dict["recording"] + unit_ids = worker_dict["unit_ids"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + inds_by_unit = worker_dict["inds_by_unit"] + sparsity_mask = worker_dict["sparsity_mask"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -383,12 +384,12 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx if in_chunk_pos.size == 0: continue - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["arrays_info"][unit_id] + filename = worker_dict["arrays_info"][unit_id] 
wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["waveforms_by_units"][unit_id] + elif worker_dict["mode"] == "shared_memory": + wfs = worker_dict["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -548,50 +549,50 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["wf_array_info"] = wf_array_info + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode if mode == "memmap": filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["shm"] = shm + worker_dict["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - sparsity_mask = worker_ctx["sparsity_mask"] - all_waveforms = worker_ctx["all_waveforms"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + sparsity_mask = worker_dict["sparsity_mask"] + all_waveforms = worker_dict["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -630,7 +631,7 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work wf = wf[:, mask] all_waveforms[spike_index, :, : wf.shape[1]] = wf - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": all_waveforms.flush() @@ -843,12 +844,6 @@ def estimate_templates_with_accumulator( waveform_squared_accumulator_per_worker = None shm_squared_name = None - # trick to get the work_index given pid 
arrays - lock = multiprocessing.Lock() - array_pid = multiprocessing.Array("i", num_worker) - for i in range(num_worker): - array_pid[i] = -1 - func = _worker_estimate_templates init_func = _init_worker_estimate_templates @@ -862,14 +857,12 @@ def estimate_templates_with_accumulator( nbefore, nafter, return_scaled, - lock, - array_pid, ) if job_name is None: job_name = "estimate_templates_with_accumulator" processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -920,15 +913,13 @@ def _init_worker_estimate_templates( nbefore, nafter, return_scaled, - lock, - array_pid, ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled from multiprocessing.shared_memory import SharedMemory import multiprocessing @@ -936,48 +927,36 @@ def _init_worker_estimate_templates( shm = SharedMemory(shm_name) waveform_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker + worker_dict["shm"] = shm + worker_dict["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker if shm_squared_name is not None: shm_squared = SharedMemory(shm_squared_name) waveform_squared_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm_squared.buf) - worker_ctx["shm_squared"] = shm_squared - worker_ctx["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker + worker_dict["shm_squared"] = shm_squared + worker_dict["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices - - child_process = multiprocessing.current_process() - - lock.acquire() - num_worker = None - for i in range(len(array_pid)): - if array_pid[i] == -1: - num_worker = i - array_pid[i] = child_process.ident - break - worker_ctx["worker_index"] = num_worker - lock.release() + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_ctx): +def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - waveform_accumulator_per_worker = worker_ctx["waveform_accumulator_per_worker"] - waveform_squared_accumulator_per_worker = worker_ctx.get("waveform_squared_accumulator_per_worker", None) - worker_index = worker_ctx["worker_index"] - return_scaled = worker_ctx["return_scaled"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = 
worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + waveform_accumulator_per_worker = worker_dict["waveform_accumulator_per_worker"] + waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None) + worker_index = worker_dict["worker_index"] + return_scaled = worker_dict["return_scaled"] seg_size = recording.get_num_samples(segment_index=segment_index) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index 5c7584ecd8..ffe4755c75 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -23,7 +23,7 @@ from .job_tools import split_job_kwargs from .sparsity import ChannelSparsity from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer -from .base import load_extractor +from .loading import load from .analyzer_extension_core import ComputeRandomSpikes, ComputeWaveforms, ComputeTemplates _backwards_compatibility_msg = """#### @@ -475,21 +475,21 @@ def _read_old_waveforms_extractor_binary(folder, sorting): recording = None if (folder / "recording.json").exists(): try: - recording = load_extractor(folder / "recording.json", base_folder=folder) + recording = load(folder / "recording.json", base_folder=folder) except: pass elif (folder / "recording.pickle").exists(): try: - recording = load_extractor(folder / "recording.pickle", base_folder=folder) + recording = load(folder / "recording.pickle", base_folder=folder) except: pass # sorting if sorting is None: if (folder / "sorting.json").exists(): - sorting = load_extractor(folder / "sorting.json", base_folder=folder) + sorting = load(folder / "sorting.json", base_folder=folder) elif (folder / "sorting.pickle").exists(): - sorting = load_extractor(folder / "sorting.pickle", base_folder=folder) + sorting = load(folder / "sorting.pickle", base_folder=folder) sorting_analyzer = SortingAnalyzer.create_memory( sorting, recording, sparsity=sparsity, return_scaled=return_scaled, rec_attributes=rec_attributes @@ -676,7 +676,7 @@ def make_ext_params_up_to_date(ext, old_params, new_params): # recording = None # try: # recording_dict = waveforms_root.attrs["recording"] -# recording = load_extractor(recording_dict, base_folder=folder) +# recording = load(recording_dict, base_folder=folder) # except: # pass @@ -684,7 +684,7 @@ def make_ext_params_up_to_date(ext, old_params, new_params): # if sorting is None: # assert "sorting" in waveforms_root.attrs, "Could not load sorting object" # sorting_dict = waveforms_root.attrs["sorting"] -# sorting = load_extractor(sorting_dict, base_folder=folder) +# sorting = load(sorting_dict, base_folder=folder) # if "sparsity" in waveforms_root.attrs: # sparsity_dict = waveforms_root.attrs["sparsity"] diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 17f1ac08b3..ff552dfb54 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -12,6 +12,19 @@ from .core_tools import define_function_from_class, check_json from .job_tools import split_job_kwargs from .recording_tools import determine_cast_unsigned +from .core_tools import is_path_remote + + +def anononymous_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None): + if is_path_remote(str(folder_path)) and storage_options is None: + try: + root 
= zarr.open(str(folder_path), mode="r", storage_options=storage_options) + except Exception as e: + storage_options = {"anon": True} + root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + else: + root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + return root class ZarrRecordingExtractor(BaseRecording): @@ -21,7 +34,11 @@ class ZarrRecordingExtractor(BaseRecording): Parameters ---------- folder_path : str or Path - Path to the zarr root folder + Path to the zarr root folder. This can be a local path or a remote path (s3:// or gcs://). + If the path is a remote path, the storage_options can be provided to specify credentials. + If the remote path is not accessible and backend_options is not provided, + the function will try to load the object in anonymous mode (anon=True), + which enables to load data from open buckets. storage_options : dict or None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. @@ -35,7 +52,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) - self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + self._root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) sampling_frequency = self._root.attrs.get("sampling_frequency", None) num_segments = self._root.attrs.get("num_segments", None) @@ -81,7 +98,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) nbytes_segment = self._root[trace_name].nbytes nbytes_stored_segment = self._root[trace_name].nbytes_stored - cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment + if nbytes_stored_segment > 0: + cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment + else: + cr_by_segment[segment_index] = np.nan total_nbytes += nbytes_segment total_nbytes_stored += nbytes_stored_segment @@ -105,7 +125,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None) if annotations is not None: self.annotate(**annotations) # annotate compression ratios - cr = total_nbytes / total_nbytes_stored + if total_nbytes_stored > 0: + cr = total_nbytes / total_nbytes_stored + else: + cr = np.nan self.annotate(compression_ratio=cr, compression_ratio_segments=cr_by_segment) self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options} @@ -150,7 +173,11 @@ class ZarrSortingExtractor(BaseSorting): Parameters ---------- folder_path : str or Path - Path to the zarr root file + Path to the zarr root file. This can be a local path or a remote path (s3:// or gcs://). + If the path is a remote path, the storage_options can be provided to specify credentials. + If the remote path is not accessible and backend_options is not provided, + the function will try to load the object in anonymous mode (anon=True), + which enables to load data from open buckets. storage_options : dict or None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. 
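A hypothetical usage sketch of the remote-loading behaviour described in this docstring; the bucket URIs are placeholders, and the `spikeinterface.core` import location for `ZarrRecordingExtractor` is assumed:

    from spikeinterface.core import ZarrRecordingExtractor

    # Open bucket: no storage_options given, so the anonymous (anon=True) fallback applies.
    recording = ZarrRecordingExtractor("s3://some-open-bucket/recording.zarr")

    # Private bucket: credentials are passed explicitly and the fallback is never used.
    recording = ZarrRecordingExtractor(
        "s3://some-private-bucket/recording.zarr",
        storage_options={"key": "ACCESS_KEY", "secret": "SECRET_KEY"},
    )
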
zarr_group : str or None, default: None @@ -165,7 +192,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) - zarr_root = self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + zarr_root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) + if zarr_group is None: self._root = zarr_root else: @@ -243,7 +271,7 @@ def read_zarr( """ # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is. # for the futur SortingAnalyzer we will have this 2 fields!!! - root = zarr.open(str(folder_path), mode="r", storage_options=storage_options) + root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options) if "channel_ids" in root.keys(): return read_zarr_recording(folder_path, storage_options=storage_options) elif "unit_ids" in root.keys(): @@ -329,8 +357,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G zarr_group.attrs["num_segments"] = int(num_segments) zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) - if "compressor" not in kwargs: - compressor = get_default_zarr_compressor() + compressor = kwargs.get("compressor", get_default_zarr_compressor()) # save sub fields spikes_group = zarr_group.create_group(name="spikes") diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 657b936fb9..975f2fe22f 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge # manual sorting, @@ -15,3 +15,7 @@ from .curation_format import validate_curation_dict, curation_label_to_dataframe, apply_curation from .sortingview_curation import apply_sortingview_curation + +# automated curation +from .model_based_curation import auto_label_units, load_model +from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Tuple import numpy as np import math @@ -12,49 +14,90 @@ HAVE_NUMBA = False from ..core import SortingAnalyzer, Templates -from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph - -_possible_presets = ["similarity_correlograms", "x_contaminations", "temporal_splits", "feature_neighbors"] +_compute_merge_presets = { + "similarity_correlograms": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "quality_score", + ], + "temporal_splits": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "presence_distance", + "quality_score", + 
], + "x_contaminations": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "quality_score", + ], + "feature_neighbors": [ + "num_spikes", + "snr", + "remove_contaminated", + "unit_locations", + "knn", + "quality_score", + ], +} _required_extensions = { - "unit_locations": ["unit_locations"], + "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "template_similarity": ["template_similarity"], - "knn": ["spike_locations", "spike_amplitudes"], + "snr": ["templates", "noise_levels"], + "template_similarity": ["templates", "template_similarity"], + "knn": ["templates", "spike_locations", "spike_amplitudes"], } -def get_potential_auto_merge( +_default_step_params = { + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 150}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.15, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, +} + + +def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + resolve_graph: bool = True, + steps_params: dict = None, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, -) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + force_copy: bool = True, + **job_kwargs, +) -> list[tuple[int | str, int | str]] | Tuple[list[tuple[int | str, int | str]], dict]: """ Algorithm to find and check potential merges between units. @@ -78,6 +121,9 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -98,47 +144,11 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. 
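A short sketch of how the refactored entry point might be called, using only names visible in this diff; `sorting_analyzer` is assumed to be an existing `SortingAnalyzer`, and the threshold override is illustrative rather than a recommendation:

    from spikeinterface.curation import compute_merge_unit_groups

    # Preset only: the step list and per-step defaults come from the preset tables above.
    merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset="x_contaminations")

    # Same idea, but overriding one step parameter and passing job kwargs for the
    # extensions that may still need to be computed.
    merge_unit_groups = compute_merge_unit_groups(
        sorting_analyzer,
        preset="similarity_correlograms",
        steps_params={"template_similarity": {"template_diff_thresh": 0.2}},
        n_jobs=4,
    )
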
- resolve_graph : bool, default: False + resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions, if not already computed? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None @@ -146,157 +156,141 @@ def get_potential_auto_merge( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + steps_params : dict + A dictionary whose keys are the steps, and keys are steps parameters. + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite Returns ------- - potential_merges: - A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). - List of pairs that could be merged. + merge_unit_groups: + List of groups that need to be merge. + When `resolve_graph` is true (default) a list of tuples of 2+ elements + If `resolve_graph` is false then a list of tuple of 2 elements is returned instead. 
outs: Returned only when extra_outputs=True A dictionary that contains data for debugging and plotting. References ---------- - This function is inspired and built upon similar functions from Lussac [Llobet]_, + This function used to be inspired and built upon similar functions from Lussac [Llobet]_, done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + + However, it has been greatly consolidated and refined depending on the presets. """ import scipy sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - # to get fast computation we will not analyse pairs when: - # * not enough spikes for one of theses - # * auto correlogram is contaminated - # * to far away one from each other - - all_steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "correlogram", - "template_similarity", - "presence_distance", - "knn", - "cross_contamination", - "quality_score", - ] - - if preset is not None and preset not in _possible_presets: - raise ValueError(f"preset must be one of {_possible_presets}") - - if steps is None: - if preset is None: - if steps is None: - raise ValueError("You need to specify a preset or steps for the auto-merge function") - elif preset == "similarity_correlograms": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "correlogram", - "quality_score", - ] - elif preset == "temporal_splits": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "presence_distance", - "quality_score", - ] - elif preset == "x_contaminations": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "cross_contamination", - "quality_score", - ] - elif preset == "feature_neighbors": - steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "knn", - "quality_score", - ] - + if preset is None and steps is None: + raise ValueError("You need to specify a preset or steps for the auto-merge function") + elif steps is not None: + # steps has precedence on presets + pass + elif preset is not None: + if preset not in _compute_merge_presets: + raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") + steps = _compute_merge_presets[preset] + + # check at least one extension is needed + at_least_one_extension_to_compute = False for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() n = unit_ids.size - pair_mask = np.triu(np.arange(n)) > 0 + pair_mask = np.triu(np.arange(n), 1) > 0 outs = dict() for step in steps: - assert step in all_steps, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + + # special case for templates + if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + else: + sorting_analyzer.compute(ext, **job_kwargs) + + params = 
_default_step_params.get(step).copy() + if steps_params is not None and step in steps_params: + params.update(steps_params[step]) # STEP : remove units with too few spikes if step == "num_spikes": + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < params["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=params["refractory_period_ms"], + censored_period_ms=params["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > params["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel - elif step == "unit_locations" in steps: + elif step == "unit_locations": location_ext = sorting_analyzer.get_extension("unit_locations") unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= params["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram - elif step == "correlogram" in steps: + elif step == "correlogram": correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = params["censor_correlograms_ms"] + sigma_smooth_ms = params["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * params["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +300,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins 
outs["correlograms_smoothed"] = correlograms_smoothed @@ -314,22 +308,21 @@ def get_potential_auto_merge( outs["win_sizes"] = win_sizes # STEP : check if potential merge with CC also have template similarity - elif step == "template_similarity" in steps: + elif step == "template_similarity": template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes - elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + elif step == "knn": + pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times - elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + elif step == "presence_distance": + presence_distance_kwargs = params.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -340,40 +333,243 @@ def get_potential_auto_merge( outs["presence_distances"] = presence_distances # STEP : check if the cross contamination is significant - elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + elif step == "cross_contamination": + refractory = ( + params["censored_period_ms"], + params["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, params["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > params["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics - elif step == "quality_score" in steps: + elif step == "quality_score": pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + params["firing_contamination_balance"], + params["refractory_period_ms"], + params["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) - potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) - - # some methods return identities ie (1,1) which we can cleanup first. 
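Once the pair mask is resolved into merge groups, applying them is a single call on the analyzer. A sketch of the two routes this diff provides, where `sorting_analyzer` is assumed to exist and `merging_mode="hard"` is borrowed from the tests updated earlier in this patch:

    from spikeinterface.curation import compute_merge_unit_groups, auto_merge_units

    # Route 1: compute the groups, then merge them explicitly.
    merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset="x_contaminations")
    merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, merging_mode="hard")

    # Route 2: the auto_merge_units() wrapper defined below does both steps in one call.
    merged_analyzer = auto_merge_units(
        sorting_analyzer,
        compute_merge_kwargs=dict(preset="x_contaminations"),
        apply_merge_kwargs=dict(merging_mode="hard"),
    )
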
- potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] + merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) if resolve_graph: - potential_merges = resolve_merging_graph(sorting, potential_merges) + merge_unit_groups = resolve_merging_graph(sorting, merge_unit_groups) if extra_outputs: - return potential_merges, outs + return merge_unit_groups, outs else: - return potential_merges + return merge_unit_groups + + +def auto_merge_units( + sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +) -> SortingAnalyzer: + """ + Compute merge unit groups and apply them to a SortingAnalyzer. + Internally uses `compute_merge_unit_groups()`. + """ + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs + ) + + merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) + return merged_analyzer + + +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 1.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + This function is deprecated. Use compute_merge_unit_groups() instead. + This will be removed in 0.103.0 + + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-correlogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), weighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + IMPORTANT: internally, all computations rely on extensions of the analyzer, which are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...)
If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 1.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". 
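As a hedged illustration of the replacement API (not part of the patch itself): the thresholds described above are now passed per step through `steps_params` rather than as flat keyword arguments. The `sorting_analyzer` below is assumed to already exist with the extensions the preset needs (unit locations, template similarity, correlograms), and the threshold values are purely illustrative.

```python
from spikeinterface.curation import compute_merge_unit_groups

# Propose merges with a preset, overriding a couple of per-step parameters.
merge_unit_groups = compute_merge_unit_groups(
    sorting_analyzer,
    preset="similarity_correlograms",
    resolve_graph=True,  # collapse overlapping pairs into multi-unit groups
    steps_params={
        "num_spikes": {"min_spikes": 300},
        "template_similarity": {"template_diff_thresh": 0.2},
    },
)

# Apply the proposed merges; this returns a new, merged SortingAnalyzer.
merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups)
```

The `auto_merge_units()` helper defined above wraps these two calls (compute, then apply) into one.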
+ k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. + + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + warnings.warn( + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) + + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return compute_merge_unit_groups( + sorting_analyzer, + preset, + resolve_graph, + steps_params={ + "num_spikes": {"min_spikes": min_spikes}, + "snr": {"min_snr": min_snr}, + "remove_contaminated": { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations": {"max_distance_um": max_distance_um}, + "correlogram": { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance": {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn": {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination": { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score": { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + }, + compute_needed_extensions=True, + extra_outputs=extra_outputs, + steps=steps, + ) def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): @@ -661,10 +857,10 @@ def check_improve_contaminations_score( f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores - k = firing_contamination_balance - score_1 = f_1 * (1 - (k + 1) * c_1) - score_2 = f_2 * (1 - (k + 1) * c_2) - score_new = f_new * (1 - (k + 1) * c_new) + k = 1 + firing_contamination_balance + score_1 = f_1 * (1 - k * c_1) + score_2 = f_2 * (1 - k * c_2) + score_new = f_new * (1 - k * c_new) if score_new < score_1 or score_new < score_2: # the score is not improved diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 5f85538b08..80f251ca43 100644 --- 
a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict): if not removed_units_set.issubset(unit_set): raise ValueError("Curation format: some removed units are not in the unit list") + for group in curation_dict["merge_unit_groups"]: + if len(group) < 2: + raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") + all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Some units belong to multiple merge groups") + raise ValueError("Curation format: some units belong to multiple merge groups") if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Some units were merged and deleted") + raise ValueError("Curation format: some units were merged and deleted") # Check the labels exclusivity for lbl in curation_dict["manual_labels"]: @@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): if unit_id not in new_unit_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_dict["unit_ids"]).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) @@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): group_values.append(value) if len(set(group_values)) == 1: # all group has the same label or empty - sorting.set_property(key, values=group_values, ids=[new_unit_id]) + sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: for key in label_def["label_options"]: @@ -339,18 +343,22 @@ def apply_curation( elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], - censor_ms=censor_ms, - merging_mode=merging_mode, - sparsity_overlap=sparsity_overlap, - new_id_strategy=new_id_strategy, - return_new_unit_ids=True, - format="memory", - verbose=verbose, - **job_kwargs, - ) + if len(curation_dict["removed_units"]) > 0: + analyzer = analyzer.remove_units(curation_dict["removed_units"]) + if len(curation_dict["merge_unit_groups"]) > 0: + analyzer, new_unit_ids = analyzer.merge_units( + curation_dict["merge_unit_groups"], + censor_ms=censor_ms, + merging_mode=merging_mode, + sparsity_overlap=sparsity_overlap, + new_id_strategy=new_id_strategy, + return_new_unit_ids=True, + format="memory", + verbose=verbose, + **job_kwargs, + ) + else: + new_unit_ids = [] apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py new file mode 100644 index 0000000000..93ad03734c --- /dev/null +++ b/src/spikeinterface/curation/model_based_curation.py @@ -0,0 +1,435 @@ +import numpy as np +from pathlib import Path +import json +import warnings +import re + +from spikeinterface.core import SortingAnalyzer +from spikeinterface.curation.train_manual_curation import ( + try_to_get_metrics_from_analyzer, + _get_computed_metrics, + _format_metric_dataframe, +) +from copy import deepcopy + + +class ModelBasedClassification: + """ + Class for 
performing model-based classification on spike sorting data. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + + Attributes + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + required_metrics : Sequence[str] + The list of required metrics for classification, extracted from the pipeline. + + Methods + ------- + predict_labels() + Predicts the labels for the spike sorting data using the trained model. + """ + + def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline): + from sklearn.pipeline import Pipeline + + if not isinstance(pipeline, Pipeline): + raise ValueError("The `pipeline` must be an instance of sklearn.pipeline.Pipeline") + + self.sorting_analyzer = sorting_analyzer + self.pipeline = pipeline + self.required_metrics = pipeline.feature_names_in_ + + def predict_labels( + self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False + ): + """ + Predicts the labels for the spike sorting data using the trained model. + Populates the sorting object with the predicted labels and probabilities as unit properties + + Parameters + ---------- + model_info : dict or None, default: None + Model info, generated with model, used to check metric parameters used to train it. + label_conversion : dict or None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to find in `model_info` file. The dictionary should have the format {old_label: new_label}. + input_data : pandas.DataFrame or None, default: None + The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer. + export_to_phy : bool, default: False. + Whether to export the classified units to Phy format. Default is False. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + Returns + ------- + pd.DataFrame + A dataframe containing the classified units and their corresponding predictions and probabilities, + indexed by their `unit_ids`. 
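A sketch of how this method is meant to be called (illustrative, not from the patch): `pipeline` stands for an already-trained sklearn Pipeline, and `sorting_analyzer` for an analyzer whose quality and template metrics have been computed.

```python
from spikeinterface.curation.model_based_curation import ModelBasedClassification

# `pipeline` is an assumed, pre-trained sklearn Pipeline; its `feature_names_in_`
# attribute defines which metric columns must be present in the analyzer.
classification = ModelBasedClassification(sorting_analyzer, pipeline)
labels_df = classification.predict_labels(label_conversion={0: "noise", 1: "good"})

print(labels_df)  # columns "prediction" and "probability", indexed by unit_ids
print(sorting_analyzer.sorting.get_property("classifier_label"))
```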
+ """ + import pandas as pd + + # Get metrics DataFrame for classification + if input_data is None: + input_data = _get_computed_metrics(self.sorting_analyzer) + else: + if not isinstance(input_data, pd.DataFrame): + raise ValueError("Input data must be a pandas DataFrame") + + input_data = self._check_required_metrics_are_present(input_data) + + if model_info is not None: + self._check_params_for_classification(enforce_metric_params, model_info=model_info) + + if model_info is not None and label_conversion is None: + try: + string_label_conversion = model_info["label_conversion"] + # json keys are strings; we convert these to ints + label_conversion = {} + for key, value in string_label_conversion.items(): + label_conversion[int(key)] = value + except: + warnings.warn("Could not find `label_conversion` key in `model_info.json` file") + + input_data = _format_metric_dataframe(input_data) + + # Apply classifier + predictions = self.pipeline.predict(input_data) + probabilities = self.pipeline.predict_proba(input_data) + probabilities = np.max(probabilities, axis=1) + + if isinstance(label_conversion, dict): + + if set(predictions).issubset(set(label_conversion.keys())) is False: + raise ValueError("Labels in predictions do not match those in label_conversion") + predictions = [label_conversion[label] for label in predictions] + + classified_units = pd.DataFrame( + zip(predictions, probabilities), columns=["prediction", "probability"], index=self.sorting_analyzer.unit_ids + ) + + # Set predictions and probability as sorting properties + self.sorting_analyzer.sorting.set_property("classifier_label", predictions) + self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities) + + if export_to_phy: + self._export_to_phy(classified_units) + + return classified_units + + def _check_required_metrics_are_present(self, calculated_metrics): + + # Check all the required metrics have been calculated + required_metrics = set(self.required_metrics) + if required_metrics.issubset(set(calculated_metrics)): + input_data = calculated_metrics[self.required_metrics] + else: + raise ValueError( + "Input data does not contain all required metrics for classification", + f"Missing metrics: {required_metrics.difference(calculated_metrics)}", + ) + + return input_data + + def _check_params_for_classification(self, enforce_metric_params=False, model_info=None): + """ + Check that quality and template metrics parameters match those used to train the model + + Parameters + ---------- + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + model_info : dict, default: None + Dictionary of model info containing provenance of the model. 
+ """ + + extension_names = ["quality_metrics", "template_metrics"] + + metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + model_extension_params = model_info["metric_params"].get(extension_name + "_params") + + if metric_extension is not None and model_extension_params is not None: + + metric_params = metric_extension.params["metric_params"] + + inconsistent_metrics = [] + for metric in model_extension_params["metric_names"]: + model_metric_params = model_extension_params.get("metric_params") + if model_metric_params is None or metric not in model_metric_params: + inconsistent_metrics.append(metric) + else: + if metric_params[metric] != model_metric_params[metric]: + warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + if len(inconsistent_metrics) > 0: + warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + def _export_to_phy(self, classified_units): + """Export the classified units to Phy as cluster_prediction.tsv file""" + + import pandas as pd + + # Create a new DataFrame with unit_id, prediction, and probability columns from dict {unit_id: (prediction, probability)} + classified_df = pd.DataFrame.from_dict(classified_units, orient="index", columns=["prediction", "probability"]) + + # Export to Phy format + try: + sorting_path = self.sorting_analyzer.sorting.get_annotation("phy_folder") + assert sorting_path is not None + assert Path(sorting_path).is_dir() + except AssertionError: + raise ValueError("Phy folder not found in sorting annotations, or is not a directory") + + classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id") + + +def auto_label_units( + sorting_analyzer: SortingAnalyzer, + model_folder=None, + model_name=None, + repo_id=None, + label_conversion=None, + trust_model=False, + trusted=None, + export_to_phy=False, + enforce_metric_params=False, +): + """ + Automatically labels units based on a model-based classification, either from a model + hosted on HuggingFaceHub or one available in a local folder. + + This function returns the predicted labels and the prediction probabilities, and populates + the sorting object with the predicted labels and probabilities in the 'classifier_label' and + 'classifier_probability' properties. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting results. + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + label_conversion : dic | None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to extract from `model_info.json` file. 
The dictionary should have the format {old_label: new_label}. + export_to_phy : bool, default: False + Whether to export the results to Phy format. Default is False. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + + Returns + ------- + classified_units : pd.DataFrame + A dataframe containing the classified units, indexed by the `unit_ids`, containing the predicted label + and confidence probability of each labelled unit. + + Raises + ------ + ValueError + If the pipeline is not an instance of sklearn.pipeline.Pipeline. + + """ + from sklearn.pipeline import Pipeline + + model, model_info = load_model( + model_folder=model_folder, repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + if not isinstance(model, Pipeline): + raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline") + + model_based_classification = ModelBasedClassification(sorting_analyzer, model) + + classified_units = model_based_classification.predict_labels( + label_conversion=label_conversion, + export_to_phy=export_to_phy, + model_info=model_info, + enforce_metric_params=enforce_metric_params, + ) + + return classified_units + + +def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a HuggingFaceHub repo or a local folder. + + Parameters + ---------- + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. 
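A hedged sketch of loading from the Hugging Face Hub rather than a local folder; the repo id below is a placeholder, not a real repository.

```python
from spikeinterface.curation import load_model

model, model_info = load_model(
    repo_id="some-user/some-curation-model",  # placeholder repo id
    trusted=["numpy.dtype"],                  # explicit skops trust list
)
print(model.feature_names_in_)       # metric columns the pipeline expects
print(model_info["requirements"])    # package versions recorded at training time
```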
+ + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + if model_folder is None and repo_id is None: + raise ValueError("Please provide a 'model_folder' or a 'repo_id'.") + elif model_folder is not None and repo_id is not None: + raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.") + elif model_folder is not None: + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + else: + model, model_info = _load_model_from_huggingface( + repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_huggingface(repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model from a huggingface repo + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + from huggingface_hub import list_repo_files + from huggingface_hub import hf_hub_download + + # get repo filenames + repo_filenames = list_repo_files(repo_id=repo_id) + + # download all skops and json files to temp directory + for filename in repo_filenames: + if Path(filename).suffix in [".skops", ".json"]: + full_path = hf_hub_download(repo_id=repo_id, filename=filename) + model_folder = Path(full_path).parent + + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_folder(model_folder=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a folder + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + import skops.io as skio + from skops.io.exceptions import UntrustedTypesFoundException + + folder = Path(model_folder) + assert folder.is_dir(), f"The folder {folder}, does not exist." + + # look for any .skops files + skops_files = list(folder.glob("*.skops")) + assert len(skops_files) > 0, f"There are no '.skops' files in the folder {folder}" + + if len(skops_files) > 1: + if model_name is None: + model_names = [f.name for f in skops_files] + raise ValueError( + f"There are more than 1 '.skops' file in folder {folder}. You have to specify " + f"the file using the 'model_name' argument. Available files:\n{model_names}" + ) + else: + skops_file = folder / Path(model_name) + assert skops_file.is_file(), f"Model file {skops_file} not found." + elif len(skops_files) == 1: + skops_file = skops_files[0] + + if trust_model and trusted is None: + try: + model = skio.load(skops_file) + except UntrustedTypesFoundException as e: + exception_msg = str(e) + # the exception message contains the list of untrusted objects. The following + # search assumes it is the only list in the message. + string_list = re.search(r"\[(.*?)\]", exception_msg).group() + trusted = [list_item for list_item in string_list.split("'") if len(list_item) > 2] + + model = skio.load(skops_file, trusted=trusted) + + model_info_path = folder / "model_info.json" + if not model_info_path.is_file(): + warnings.warn("No 'model_info.json' file found in folder. 
No metadata can be checked.") + model_info = None + else: + model_info = json.load(open(model_info_path)) + + model_info = handle_backwards_compatibility_metric_params(model_info) + + return model, model_info + + +def handle_backwards_compatibility_metric_params(model_info): + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("quality_metric_params") is not None + ): + if (qm_params := model_info["metric_params"]["quality_metric_params"].get("qm_params")) is not None: + model_info["metric_params"]["quality_metric_params"]["metric_params"] = qm_params + del model_info["metric_params"]["quality_metric_params"]["qm_params"] + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("template_metric_params") is not None + ): + if (tm_params := model_info["metric_params"]["template_metric_params"].get("metrics_kwargs")) is not None: + metric_params = {} + for metric_name in model_info["metric_params"]["template_metric_params"].get("metric_names"): + metric_params[metric_name] = deepcopy(tm_params) + model_info["metric_params"]["template_metric_params"]["metric_params"] = metric_params + del model_info["metric_params"]["template_metric_params"]["metrics_kwargs"] + + return model_info diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 9cd20f4bfc..e9c4c4a463 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..4c05f41a4c 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,16 +3,16 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import get_potential_auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @pytest.mark.parametrize( - "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"] + "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None] ) -def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): +def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting @@ -47,32 +47,38 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): ) if preset is not None: - potential_merges, outs = get_potential_auto_merge( + # do not resolve graph for checking true pairs + merge_unit_groups, outs = compute_merge_unit_groups( sorting_analyzer, preset=preset, - min_spikes=1000, - 
max_distance_um=150.0, - contamination_thresh=0.2, - corr_diff_thresh=0.16, - template_diff_thresh=0.25, - censored_period_ms=0.0, - refractory_period_ms=4.0, - sigma_smooth_ms=0.6, - adaptative_window_thresh=0.5, - firing_contamination_balance=1.5, + resolve_graph=False, + # min_spikes=1000, + # max_distance_um=150.0, + # contamination_thresh=0.2, + # corr_diff_thresh=0.16, + # template_diff_thresh=0.25, + # censored_period_ms=0.0, + # refractory_period_ms=4.0, + # sigma_smooth_ms=0.6, + # adaptative_window_thresh=0.5, + # firing_contamination_balance=1.5, extra_outputs=True, + **job_kwargs, ) if preset == "x_contaminations": - assert len(potential_merges) == num_unit_splited + assert len(merge_unit_groups) == num_unit_splited for true_pair in other_ids.values(): true_pair = tuple(true_pair) - assert true_pair in potential_merges + assert true_pair in merge_unit_groups else: # when preset is None you have to specify the steps with pytest.raises(ValueError): - potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) - potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, + preset=preset, + steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], + **job_kwargs, ) # DEBUG @@ -93,7 +99,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): # m = correlograms.shape[2] // 2 - # for unit_id1, unit_id2 in potential_merges[:5]: + # for unit_id1, unit_id2 in merge_unit_groups[:5]: # unit_ind1 = sorting_with_split.id_to_index(unit_id1) # unit_ind2 = sorting_with_split.id_to_index(unit_id2) @@ -129,4 +135,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_get_auto_merge_list(sorting_analyzer) + # preset = "x_contaminations" + preset = None + test_compute_merge_unit_groups(sorting_analyzer, preset=preset) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py new file mode 100644 index 0000000000..3683b417df --- /dev/null +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -0,0 +1,167 @@ +import pytest +from pathlib import Path +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.model_based_curation import ModelBasedClassification +from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation.train_manual_curation import _get_computed_metrics + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "curation" +else: + cache_folder = Path("cache_folder") / "curation" + + +@pytest.fixture +def model(): + """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. 
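Several fixtures in this patch rename generated units and channels to plain integer ids before building analyzers; a minimal sketch of that pattern, assuming `generate_sorting` as used by the sortingview tests further down:

```python
from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=10)
unit_ids_as_int = [id for id in range(sorting.get_num_units())]
sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)
print(sorting.unit_ids)  # integer ids 0..9
```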
+ It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with + the following labels: [1,0,1,0,1].""" + + model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) + return model + + +@pytest.fixture +def required_metrics(): + """These are the metrics which `model` are trained on.""" + return ["num_spikes", "snr", "half_width"] + + +def test_model_based_classification_init(sorting_analyzer_for_curation, model): + """Test that the ModelBasedClassification attributes are correctly initialised""" + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation + assert model_based_classification.pipeline == model[0] + assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) + + +def test_metric_ordering_independence(sorting_analyzer_for_curation, model): + """The function `auto_label_units` needs the correct metrics to have been computed. However, + it should be independent of the order of computation. We test this here.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + prediction_prob_dataframe_1 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) + + prediction_prob_dataframe_2 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2) + + +def test_model_based_classification_get_metrics_for_classification( + sorting_analyzer_for_curation, model, required_metrics +): + """If the user has not computed the required metrics, an error should be returned. + This test checks that an error occurs when the required metrics have not been computed, + and that no error is returned when the required metrics have been computed. 
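To satisfy this check in practice, both metric extensions have to be computed; a short sketch on an assumed existing `sorting_analyzer`, mirroring the toy model's required metrics:

```python
from spikeinterface.curation.train_manual_curation import _get_computed_metrics

# Quality metrics and template metrics live in two different extensions.
sorting_analyzer.compute("quality_metrics", metric_names=["num_spikes", "snr"])
sorting_analyzer.compute("template_metrics", metric_names=["half_width"])

metrics = _get_computed_metrics(sorting_analyzer)  # single DataFrame combining both
print(metrics.columns.to_list())
```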
+ """ + + sorting_analyzer_for_curation.delete_extension("quality_metrics") + sorting_analyzer_for_curation.delete_extension("template_metrics") + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + + # Check that ValueError is returned when no metrics are present in sorting_analyzer + with pytest.raises(ValueError): + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + + # Compute some (but not all) of the required metrics in sorting_analyzer, should still error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + with pytest.raises(ValueError): + model_based_classification._check_required_metrics_are_present(computed_metrics) + + # Compute all of the required metrics in sorting_analyzer, no more error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) + + metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) + assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) + assert set(metrics_data.columns.to_list()) == set(required_metrics) + + +def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): + # Test the _export_to_phy() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = {0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)} + # Function should fail here + with pytest.raises(ValueError): + model_based_classification._export_to_phy(classified_units) + # Make temp output folder and set as phy_folder + phy_folder = cache_folder / "phy_folder" + phy_folder.mkdir(parents=True, exist_ok=True) + + model_based_classification.sorting_analyzer.sorting.annotate(phy_folder=phy_folder) + model_based_classification._export_to_phy(classified_units) + assert (phy_folder / "cluster_prediction.tsv").exists() + + +def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model): + """The model `model` has been trained on the `sorting_analyzer` used in this test with + the labels `[1, 0, 1, 0, 1]`. Hence if we apply the model to this `sorting_analyzer` + we expect these labels to be outputted. 
The test checks this, and also checks + that label conversion works as expected.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + # Test the predict_labels() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = model_based_classification.predict_labels() + predictions = classified_units["prediction"].values + + assert np.all(predictions == np.array([1, 0, 1, 0, 1])) + + conversion = {0: "noise", 1: "good"} + classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion) + predictions_labelled = classified_units_labelled["prediction"] + assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) + + +def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): + """We track whether the metric parameters used to compute the metrics used to train + a model are the same as the parameters used to compute the metrics in the sorting + analyzer which is being curated. If they are different, an error or warning will + be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" + + sorting_analyzer_for_curation.compute( + "quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + ) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + + # an error should be raised if `enforce_metric_params` is True + with pytest.raises(Exception): + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + # but only a warning if `enforce_metric_params` is False + with pytest.warns(UserWarning): + model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) + + # Now test the positive case. Recompute using the default parameters + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 945aca7937..ff80be365d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -49,6 +49,9 @@ def test_gh_curation(): Test curation using GitHub URI. 
""" sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) + # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" @@ -76,6 +79,8 @@ def test_sha1_curation(): Test curation using SHA1 URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from SHA1 # curated link: @@ -105,6 +110,8 @@ def test_json_curation(): Test curation using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" @@ -248,6 +255,8 @@ def test_json_no_merge_curation(): Test curation with no merges using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) json_file = parent_folder / "sv-sorting-curation-no-merge.json" sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py new file mode 100644 index 0000000000..f455fbdb9c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -0,0 +1,285 @@ +import pytest +import numpy as np +import tempfile, csv +from pathlib import Path + +from spikeinterface.curation.tests.common import make_sorting_analyzer +from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model + + +@pytest.fixture +def trainer(): + """A simple CurationModelTrainer object is created, which can later by used to + train models using data from `sorting_analyzer`s.""" + + folder = tempfile.mkdtemp() # Create a temporary output folder + imputation_strategies = ["median"] + scaling_techniques = ["standard_scaler"] + classifiers = ["LogisticRegression"] + metric_names = ["metric1", "metric2", "metric3"] + search_kwargs = {"cv": 3} + return CurationModelTrainer( + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + search_kwargs=search_kwargs, + ) + + +def make_temp_training_csv(): + """Create a temporary CSV file with artificially generated quality metrics. + The data is designed to be easy to dicern between units. Even units metric + values are all `0`, while odd units metric values are all `1`. 
+ """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + writer = csv.writer(temp_file) + writer.writerow(["unit_id", "metric1", "metric2", "metric3"]) + for i in range(5): + writer.writerow([i * 2, 0, 0, 0]) + writer.writerow([i * 2 + 1, 1, 1, 1]) + return temp_file.name + + +def test_load_and_preprocess_full(trainer): + """Check that we load and preprocess the csv file from `make_temp_training_csv` + correctly.""" + temp_file_path = make_temp_training_csv() + + # Load and preprocess the data from the temporary CSV file + trainer.load_and_preprocess_csv([temp_file_path]) + + # Assert that the data is loaded and preprocessed correctly + for a, row in trainer.X.iterrows(): + assert np.all(row.values == [float(a % 2)] * 3) + for a, label in enumerate(trainer.y.values): + assert label == a % 2 + for a, row in trainer.testing_metrics.iterrows(): + assert np.all(row.values == [a % 2] * 3) + assert row.name == a + + +def test_apply_scaling_imputation(trainer): + """Take a simple training and test set and check that they are corrected scaled, + using a standard scaler which rescales the training distribution to have mean 0 + and variance 1. Length between each row is 3, so if x0 is the first value in the + column, all other values are scaled as x -> 2/3(x - x0) - 1. The y (labled) values + do not get scaled.""" + + from sklearn.impute._knn import KNNImputer + from sklearn.preprocessing._data import StandardScaler + + imputation_strategy = "knn" + scaling_technique = "standard_scaler" + X_train = np.array([[1, 2, 3], [4, 5, 6]]) + X_test = np.array([[7, 8, 9], [10, 11, 12]]) + y_train = np.array([0, 1]) + y_test = np.array([2, 3]) + + X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation( + imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test + ) + + first_row_elements = X_train[0] + for a, row in enumerate(X_train): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a]) + for a, row in enumerate(X_test): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a]) + + assert np.all(y_train == y_train_scaled) + assert np.all(y_test == y_test_scaled) + + assert isinstance(imputer, KNNImputer) + assert isinstance(scaler, StandardScaler) + + +def test_get_classifier_search_space(trainer): + """For each classifier, there is a hyperparameter space we search over to find its + most accurate incarnation. 
Here, we check that we do indeed load the approprirate + dict of hyperparameter possibilities""" + + from sklearn.linear_model._logistic import LogisticRegression + + classifier = "LogisticRegression" + model, param_space = trainer.get_classifier_search_space(classifier) + + assert isinstance(model, LogisticRegression) + assert len(param_space) > 0 + assert isinstance(param_space, dict) + + +def test_get_custom_classifier_search_space(): + """Check that if a user passes a custom hyperparameter search space, that this is + passed correctly to the trainer.""" + + classifier = { + "LogisticRegression": { + "C": [0.1, 8.0], + "solver": ["lbfgs"], + "max_iter": [100, 400], + } + } + trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]) + + model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0]) + assert param_space == classifier["LogisticRegression"] + + +def test_saved_files(trainer): + """During the trainer's creation, the following files should be created: + - best_model.skops + - labels.csv + - model_accuracies.csv + - model_info.json + - training_data.csv + This test checks that these exist, and checks some properties of the files.""" + + import pandas as pd + import json + + trainer.X = np.random.rand(10, 3) + trainer.y = np.append(np.ones(5), np.zeros(5)) + + trainer.evaluate_model_config() + trainer_folder = Path(trainer.folder) + + assert trainer_folder.is_dir() + + best_model_path = trainer_folder / "best_model.skops" + model_accuracies_path = trainer_folder / "model_accuracies.csv" + training_data_path = trainer_folder / "training_data.csv" + labels_path = trainer_folder / "labels.csv" + model_info_path = trainer_folder / "model_info.json" + + assert (best_model_path).is_file() + + model_accuracies = pd.read_csv(model_accuracies_path) + model_accuracies["classifier name"].values[0] == "LogisticRegression" + assert len(model_accuracies) == 1 + + training_data = pd.read_csv(training_data_path) + assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10)) + + labels = pd.read_csv(labels_path) + assert np.all(labels.values[:, 1] == trainer.y.astype("float")) + + model_info = pd.read_json(model_info_path) + + with open(model_info_path) as f: + model_info = json.load(f) + + assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"]) + + +def test_train_model(): + """A simple function test to check that `train_model` doesn't fail with one csv inputs""" + + metrics_path = make_temp_training_csv() + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_model_using_two_csvs(): + """Models can be trained using more than one set of training data. 
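Putting the csv mode together with model loading, a hedged end-to-end sketch; the file and folder names come from the earlier illustrative snippets and are hypothetical.

```python
from spikeinterface.curation.train_manual_curation import train_model
from spikeinterface.curation import load_model

train_model(
    mode="csv",
    metrics_paths=["toy_training_metrics.csv"],  # hypothetical, see the pandas sketch above
    folder="toy_model_folder",                   # outputs: best_model.skops, model_info.json, ...
    labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
    metric_names=["metric1", "metric2", "metric3"],
    imputation_strategies=["median"],
    scaling_techniques=["standard_scaler"],
    classifiers=["LogisticRegression"],
    overwrite=True,
    search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1},
)

model, model_info = load_model(model_folder="toy_model_folder", trust_model=True)
```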
This test checks + that `train_model` works with two inputs, from csv files.""" + + metrics_path_1 = make_temp_training_csv() + metrics_path_2 = make_temp_training_csv() + + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path_1, metrics_path_2], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_using_two_sorting_analyzers(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from sorting analzyers. It also checks that + an error is raised if the sorting_analyzers have different sets of metrics computed.""" + + sorting_analyzer_1 = make_sorting_analyzer() + sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + labels_1 = [0, 1, 1, 1, 1] + labels_2 = [1, 1, 0, 1, 1] + + folder = tempfile.mkdtemp() + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + assert isinstance(trainer, CurationModelTrainer) + + # Check that there is an error raised if the metric names are different + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}}) + + with pytest.raises(Exception): + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + # Now check that there is an error raised if we demand the same metric params, but don't have them + + sorting_analyzer_2.compute( + { + "quality_metrics": { + "metric_names": ["num_spikes", "snr"], + "metric_params": {"snr": {"peak_mode": "at_index"}}, + } + } + ) + + with pytest.raises(Exception): + train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + overwrite=True, + enforce_metric_params=True, + ) diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops new file mode 100644 index 0000000000..362405f917 Binary files /dev/null and b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops differ diff --git a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv new file mode 100644 index 0000000000..46680a9e89 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv @@ -0,0 +1,21 @@ +unit_index,0 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 
+2,1 +3,0 +4,1 diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv new file mode 100644 index 0000000000..7f015c380b --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv @@ -0,0 +1,2 @@ +,classifier name,imputation_strategy,scaling_strategy,accuracy,precision,recall,model_id,best_params +0,LogisticRegression,median,StandardScaler(),1.0000,1.0000,1.0000,0,"OrderedDict([('C', 4.811707275233983), ('max_iter', 384), ('solver', 'saga')])" diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json new file mode 100644 index 0000000000..75ced28486 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json @@ -0,0 +1,60 @@ +{ + "metric_params": { + "quality_metric_params": { + "metric_names": [ + "snr", + "num_spikes" + ], + "peak_sign": null, + "seed": null, + "metric_params": { + "num_spikes": {}, + "snr": { + "peak_sign": "neg", + "peak_mode": "extremum" + } + }, + "skip_pc_metrics": false, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "snr", + "num_spikes" + ] + }, + "template_metric_params": { + "metric_names": [ + "half_width" + ], + "sparsity": null, + "peak_sign": "neg", + "upsampling_factor": 10, + "metric_params": { + "half_width": { + "recovery_window_ms": 0.7, + "peak_relative_threshold": 0.2, + "peak_width_ms": 0.1, + "depth_direction": "y", + "min_channels_for_velocity": 5, + "min_r2_velocity": 0.5, + "exp_peak_function": "ptp", + "min_r2_exp_decay": 0.5, + "spread_threshold": 0.2, + "spread_smooth_um": 20, + "column_range": null + } + }, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "half_width" + ] + } + }, + "requirements": { + "spikeinterface": "0.101.1", + "scikit-learn": "1.3.2" + }, + "label_conversion": { + "1": 1, + "0": 0 + } +} diff --git a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv new file mode 100644 index 0000000000..c9efca17ad --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv @@ -0,0 +1,21 @@ +unit_id,snr,num_spikes,half_width +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py new file mode 100644 index 0000000000..7b315b0fba --- /dev/null +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -0,0 +1,843 @@ +import os +import warnings +import numpy as np +import json +import spikeinterface +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.qualitymetrics import ( + get_quality_metric_list, + get_quality_pca_metric_list, + qm_compute_name_to_column_names, +) +from 
spikeinterface.postprocessing import get_template_metric_names +from spikeinterface.postprocessing.template_metrics import tm_compute_name_to_column_names +from pathlib import Path +from copy import deepcopy + + +def get_default_classifier_search_spaces(): + + from scipy.stats import uniform, randint + + default_classifier_search_spaces = { + "RandomForestClassifier": { + "n_estimators": [100, 150], + "criterion": ["gini", "entropy"], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + "class_weight": ["balanced", "balanced_subsample"], + }, + "AdaBoostClassifier": { + "learning_rate": [1, 2], + "n_estimators": [50, 100], + "algorithm": ["SAMME", "SAMME.R"], + }, + "GradientBoostingClassifier": { + "learning_rate": uniform(0.05, 0.1), + "n_estimators": randint(100, 150), + "max_depth": [2, 4], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + }, + "SVC": { + "C": uniform(0.001, 10.0), + "kernel": ["sigmoid", "rbf"], + "gamma": uniform(0.001, 10.0), + "probability": [True], + }, + "LogisticRegression": { + "C": uniform(0.001, 10.0), + "solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"], + "max_iter": [100], + }, + "XGBClassifier": { + "max_depth": [2, 4], + "eta": uniform(0.2, 0.5), + "sampling_method": ["uniform"], + "grow_policy": ["depthwise", "lossguide"], + }, + "CatBoostClassifier": {"depth": [2, 4], "learning_rate": uniform(0.05, 0.15), "n_estimators": [100, 150]}, + "LGBMClassifier": {"learning_rate": uniform(0.05, 0.15), "n_estimators": randint(100, 150)}, + "MLPClassifier": { + "activation": ["tanh", "relu"], + "solver": ["adam"], + "alpha": uniform(1e-7, 1e-1), + "learning_rate": ["constant", "adaptive"], + "n_iter_no_change": [32], + }, + } + + return default_classifier_search_spaces + + +class CurationModelTrainer: + """ + Used to train and evaluate machine learning models for spike sorting curation. + + Parameters + ---------- + labels : list of lists, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + folder : str, default: None + The folder where outputs such as models and evaluation metrics will be saved, if specified. Requires the skops library. If None, output will not be saved on file system. + metric_names : list of str, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str or dict, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + seed : int, default: None + Random seed for reproducibility. If None, a random seed will be generated. + smote : bool, default: False + Whether to apply SMOTE for class imbalance. Default is False. Requires imbalanced-learn package. 
+ verbose : bool, default: True + If True, useful information is printed during training. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + + Attributes + ---------- + folder : str + The folder where outputs such as models and evaluation metrics will be saved. Requires the skops library. + labels : list of lists, default: None + List of curated labels for each `sorting_analyzer` and each unit; must be in the same order as the metrics data. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str + The list of classifiers to evaluate. + classifier_search_space : dict or None + Dictionary of classifiers and their hyperparameter search spaces, if provided. If None, default search spaces are used. + seed : int + Random seed for reproducibility. + metrics_list : list of str + The list of metrics to use for training. + X : pandas.DataFrame or None + The feature matrix after preprocessing. + y : pandas.Series or None + The target vector after preprocessing. + testing_metrics : dict or None + Dictionary to hold testing metrics data. + label_conversion : dict or None + Dictionary to map string labels to integer codes if target column contains string labels. + + Methods + ------- + get_default_metrics_list() + Returns the default list of metrics. + load_and_preprocess_full(path) + Loads and preprocesses the data from the given path. + load_data_file(path) + Loads the data file from the given path. + process_test_data_for_classification() + Processes the test data for classification. + apply_scaling_imputation(imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val) + Applies the specified imputation and scaling techniques to the data. + get_classifier_instance(classifier_name) + Returns an instance of the specified classifier. + get_classifier_search_space(classifier_name) + Returns the search space for hyperparameter tuning for the specified classifier. + get_classifier_search_space() + Returns the default search spaces for hyperparameter tuning for the classifiers. + evaluate_model_config(imputation_strategies, scaling_techniques, classifiers) + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. 
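For the `classifiers` argument documented above, a dictionary of hyperparameter search spaces can be supplied in the same format returned by `get_default_classifier_search_spaces()`. A minimal sketch; the ranges below are arbitrary illustrations, not recommended defaults:

    from scipy.stats import randint, uniform

    # Restrict the search to two classifiers with custom hyperparameter spaces.
    # Lists give discrete choices; scipy distributions give ranges to sample from.
    custom_classifiers = {
        "LogisticRegression": {"C": uniform(0.01, 10.0), "solver": ["lbfgs", "liblinear"], "max_iter": [200]},
        "RandomForestClassifier": {"n_estimators": randint(100, 300), "criterion": ["gini", "entropy"]},
    }

    # This dictionary can then be passed as `classifiers=custom_classifiers`
    # to `CurationModelTrainer` or `train_model`.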
+ """ + + def __init__( + self, + labels=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + seed=None, + smote=False, + verbose=True, + search_kwargs=None, + **job_kwargs, + ): + + import pandas as pd + + if imputation_strategies is None: + imputation_strategies = ["median", "most_frequent", "knn", "iterative"] + + if scaling_techniques is None: + scaling_techniques = [ + "standard_scaler", + "min_max_scaler", + "robust_scaler", + ] + + if classifiers is None: + self.classifiers = ["RandomForestClassifier"] + self.classifier_search_space = None + elif isinstance(classifiers, dict): + self.classifiers = list(classifiers.keys()) + self.classifier_search_space = classifiers + elif isinstance(classifiers, list): + self.classifiers = classifiers + self.classifier_search_space = None + else: + raise ValueError("classifiers must be a list or dictionary") + + # check if labels is a list of lists + if not all(isinstance(label, list) or isinstance(label, np.ndarray) for label in labels): + raise ValueError("labels must be a list of lists") + + self.folder = Path(folder) if folder is not None else None + self.imputation_strategies = imputation_strategies + self.scaling_techniques = scaling_techniques + self.test_size = test_size + self.seed = seed if seed is not None else np.random.default_rng(seed=None).integers(0, 2**31) + self.metrics_params = {} + self.smote = smote + self.label_conversion = None + self.verbose = verbose + self.search_kwargs = search_kwargs + + self.X = None + self.testing_metrics = None + + self.requirements = {"spikeinterface": spikeinterface.__version__} + + self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels]) + + self.metric_names = metric_names + + if self.folder is not None and not self.folder.is_dir(): + self.folder.mkdir(parents=True, exist_ok=True) + + # update job_kwargs with global ones + job_kwargs = fix_job_kwargs(job_kwargs) + self.n_jobs = job_kwargs["n_jobs"] + + def get_default_metrics_list(self): + """Returns the default list of metrics.""" + return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_names() + + def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): + """ + Loads and preprocesses the quality metrics and labels from the given list of SortingAnalyzer objects. 
+ """ + import pandas as pd + + metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers] + check_metric_names_are_the_same(metrics_for_each_analyzer) + + self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) + + # Set metric names to those calculated if not provided + if self.metric_names is None: + warnings.warn("No metric_names provided, using all metrics calculated by the analyzers") + self.metric_names = self.testing_metrics.columns.tolist() + + conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params) + + self.metrics_params = {} + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + self.metrics_params[extension_name + "_params"] = metric_extension.params + + # Only save metric params which are 1) consistent and 2) exist in metric_names + metric_names = metric_extension.params["metric_names"] + consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics))) + consistent_metric_params = { + metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics + } + self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params + + self.process_test_data_for_classification() + + def _check_metrics_parameters(self, analyzers, enforce_metric_params): + """Checks that the metrics of each analyzer have been calcualted using the same parameters""" + + extension_names = ["quality_metrics", "template_metrics"] + + conflicting_metrics = [] + for analyzer_index_1, analyzer_1 in enumerate(analyzers): + for analyzer_index_2, analyzer_2 in enumerate(analyzers): + + if analyzer_index_1 <= analyzer_index_2: + continue + else: + + metric_params_1 = {} + metric_params_2 = {} + + for extension_name in extension_names: + if (extension_1 := analyzer_1.get_extension(extension_name)) is not None: + metric_params_1.update(extension_1.params["metric_params"]) + if (extension_2 := analyzer_2.get_extension(extension_name)) is not None: + metric_params_2.update(extension_2.params["metric_params"]) + + conflicting_metrics_between_1_2 = [] + # check quality metrics params + for metric, params_1 in metric_params_1.items(): + if params_1 != metric_params_2.get(metric): + conflicting_metrics_between_1_2.append(metric) + + conflicting_metrics += conflicting_metrics_between_1_2 + + if len(conflicting_metrics_between_1_2) > 0: + warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}" + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + unique_conflicting_metrics = set(conflicting_metrics) + return unique_conflicting_metrics + + def load_and_preprocess_csv(self, paths): + self._load_data_files(paths) + self.process_test_data_for_classification() + self.get_metric_params_csv() + + def get_metric_params_csv(self): + + from itertools import chain + + qm_metric_names = list(chain.from_iterable(qm_compute_name_to_column_names.values())) + tm_metric_names = list(chain.from_iterable(tm_compute_name_to_column_names.values())) + + quality_metric_names = [] + template_metric_names = [] + + for metric_name in 
self.metric_names: + if metric_name in qm_metric_names: + quality_metric_names.append(metric_name) + if metric_name in tm_metric_names: + template_metric_names.append(metric_name) + + self.metrics_params = {} + if quality_metric_names != {}: + self.metrics_params["quality_metric_params"] = {"metric_names": quality_metric_names} + if template_metric_names != {}: + self.metrics_params["template_metric_params"] = {"metric_names": template_metric_names} + + return + + def process_test_data_for_classification(self): + """ + Cleans the input data so that it can be used by sklearn. + + Extracts the target variable and features from the loaded dataset. + It handles string labels by converting them to integer codes and reindexes the + feature matrix to match the specified metrics list. Infinite values in the features + are replaced with NaN, and any remaining NaN values are filled with zeros. + + Raises + ------ + ValueError + If the target column specified is not found in the loaded dataset. + + Notes + ----- + If the target column contains string labels, a warning is issued and the labels + are converted to integer codes. The mapping from string labels to integer codes + is stored in the `label_conversion` attribute. + """ + + # Convert string labels to integer codes to allow classification + new_y = self.y.astype("category").cat.codes + self.label_conversion = dict(zip(new_y, self.y)) + self.y = new_y + + # Extract features + try: + if (set(self.metric_names) - set(self.testing_metrics.columns) != set()) and self.verbose is True: + print( + f"Dropped metrics (calculated but not included in metric_names): {set(self.testing_metrics.columns) - set(self.metric_names)}" + ) + self.X = self.testing_metrics[self.metric_names] + except KeyError as e: + raise KeyError(f"{str(e)}, metrics_list contains invalid metric names") + + self.X = self.testing_metrics.reindex(columns=self.metric_names) + self.X = _format_metric_dataframe(self.X) + + def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test): + """Impute and scale the data using the specified techniques.""" + from sklearn.experimental import enable_iterative_imputer + from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer + from sklearn.ensemble import HistGradientBoostingRegressor + from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler + + if imputation_strategy == "knn": + imputer = KNNImputer(n_neighbors=5) + elif imputation_strategy == "iterative": + imputer = IterativeImputer( + estimator=HistGradientBoostingRegressor(random_state=self.seed), random_state=self.seed + ) + else: + imputer = SimpleImputer(strategy=imputation_strategy) + + if scaling_technique == "standard_scaler": + scaler = StandardScaler() + elif scaling_technique == "min_max_scaler": + scaler = MinMaxScaler() + elif scaling_technique == "robust_scaler": + scaler = RobustScaler() + else: + raise ValueError( + f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler." 
+ ) + + y_train_processed = y_train.astype(int) + y_test = y_test.astype(int) + + X_train_imputed = imputer.fit_transform(X_train) + X_test_imputed = imputer.transform(X_test) + X_train_processed = scaler.fit_transform(X_train_imputed) + X_test_processed = scaler.transform(X_test_imputed) + + # Apply SMOTE for class imbalance + if self.smote: + try: + from imblearn.over_sampling import SMOTE + except ModuleNotFoundError: + raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE") + smote = SMOTE(random_state=self.seed) + X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed) + + return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler + + def get_classifier_instance(self, classifier_name): + from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier + from sklearn.svm import SVC + from sklearn.linear_model import LogisticRegression + from sklearn.neural_network import MLPClassifier + + classifier_mapping = { + "RandomForestClassifier": RandomForestClassifier(random_state=self.seed), + "AdaBoostClassifier": AdaBoostClassifier(random_state=self.seed), + "GradientBoostingClassifier": GradientBoostingClassifier(random_state=self.seed), + "SVC": SVC(random_state=self.seed), + "LogisticRegression": LogisticRegression(random_state=self.seed), + "MLPClassifier": MLPClassifier(random_state=self.seed), + } + + # Check lightgbm package install + if classifier_name == "LGBMClassifier": + try: + import lightgbm + + self.requirements["lightgbm"] = lightgbm.__version__ + classifier_mapping["LGBMClassifier"] = lightgbm.LGBMClassifier(random_state=self.seed, verbose=-1) + except ImportError: + raise ImportError("Please install lightgbm package to use LGBMClassifier") + elif classifier_name == "CatBoostClassifier": + try: + import catboost + + self.requirements["catboost"] = catboost.__version__ + classifier_mapping["CatBoostClassifier"] = catboost.CatBoostClassifier( + silent=True, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install catboost package to use CatBoostClassifier") + elif classifier_name == "XGBClassifier": + try: + import xgboost + + self.requirements["xgboost"] = xgboost.__version__ + classifier_mapping["XGBClassifier"] = xgboost.XGBClassifier( + use_label_encoder=False, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install xgboost package to use XGBClassifier") + + if classifier_name not in classifier_mapping: + raise ValueError( + f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + return classifier_mapping[classifier_name] + + def get_classifier_search_space(self, classifier_name): + + default_classifier_search_spaces = get_default_classifier_search_spaces() + + if classifier_name not in default_classifier_search_spaces: + raise ValueError( + f"Unknown classifier: {classifier_name}. 
To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + model = self.get_classifier_instance(classifier_name) + if self.classifier_search_space is not None: + param_space = self.classifier_search_space[classifier_name] + else: + param_space = default_classifier_search_spaces[classifier_name] + return model, param_space + + def evaluate_model_config(self): + """ + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. + + This method splits the preprocessed data into training and testing sets, then evaluates the specified + combinations of imputation strategies, scaling techniques, and classifiers. The evaluation results are + saved to the output folder. + + Raises + ------ + ValueError + If any of the specified classifier names are not recognized. + + Notes + ----- + The method converts the classifier names to actual classifier instances before evaluating them. + The evaluation results, including the best model and its parameters, are saved to the output folder. + """ + from sklearn.model_selection import train_test_split + + X_train, X_test, y_train, y_test = train_test_split( + self.X, self.y, test_size=self.test_size, random_state=self.seed, stratify=self.y + ) + classifier_instances = [self.get_classifier_instance(clf) for clf in self.classifiers] + self._evaluate( + self.imputation_strategies, + self.scaling_techniques, + classifier_instances, + X_train, + X_test, + y_train, + y_test, + self.search_kwargs, + ) + + def _load_data_files(self, paths): + import pandas as pd + + self.testing_metrics = pd.concat([pd.read_csv(path, index_col=0) for path in paths], axis=0) + + def _evaluate( + self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test, search_kwargs + ): + from joblib import Parallel, delayed + from sklearn.pipeline import Pipeline + import pandas as pd + + results = Parallel(n_jobs=self.n_jobs)( + delayed(self._train_and_evaluate)( + imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, idx, search_kwargs + ) + for idx, (imputation_strategy, scaler, classifier) in enumerate( + (imputation_strategy, scaler, classifier) + for imputation_strategy in imputation_strategies + for scaler in scaling_techniques + for classifier in classifiers + ) + ) + + test_accuracies, models = zip(*results) + + if self.search_kwargs is None or self.search_kwargs.get("scoring"): + scoring_method = "balanced_accuracy" + else: + scoring_method = self.search_kwargs.get("scoring") + + self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False) + + best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"]) + best_model, best_imputer, best_scaler = models[best_model_id] + + best_pipeline = Pipeline( + [("imputer", best_imputer), ("scaler", best_scaler), ("classifier", best_model.best_estimator_)] + ) + + self.best_pipeline = best_pipeline + + if self.folder is not None: + self._save() + + def _save(self): + from skops.io import dump + import sklearn + import pandas as pd + + # export training data and labels + pd.DataFrame(self.X).to_csv(self.folder / f"training_data.csv", index_label="unit_id") + pd.DataFrame(self.y).to_csv(self.folder / f"labels.csv", index_label="unit_index") + + self.requirements["scikit-learn"] = sklearn.__version__ + + # Dump to skops if folder is provided + dump(self.best_pipeline, self.folder 
/ f"best_model.skops") + self.test_accuracies_df.to_csv(self.folder / f"model_accuracies.csv", float_format="%.4f") + + model_info = {} + model_info["metric_params"] = self.metrics_params + + model_info["requirements"] = self.requirements + + model_info["label_conversion"] = self.label_conversion + + param_file = self.folder / "model_info.json" + Path(param_file).write_text(json.dumps(model_info, indent=4), encoding="utf8") + + def _train_and_evaluate( + self, imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, model_id, search_kwargs + ): + from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score + + search_kwargs = set_default_search_kwargs(search_kwargs) + + X_train_scaled, X_test_scaled, y_train, y_test, imputer, scaler = self.apply_scaling_imputation( + imputation_strategy, scaler, X_train, X_test, y_train, y_test + ) + if self.verbose is True: + print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}") + model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) + + try: + from skopt import BayesSearchCV + + model = BayesSearchCV( + model, + param_space, + random_state=self.seed, + **search_kwargs, + ) + except: + if self.verbose is True: + print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV") + from sklearn.model_selection import RandomizedSearchCV + + model = RandomizedSearchCV(model, param_space, **search_kwargs) + + model.fit(X_train_scaled, y_train) + y_pred = model.predict(X_test_scaled) + balanced_acc = balanced_accuracy_score(y_test, y_pred) + precision = precision_score(y_test, y_pred, average="macro") + recall = recall_score(y_test, y_pred, average="macro") + return { + "classifier name": classifier.__class__.__name__, + "imputation_strategy": imputation_strategy, + "scaling_strategy": scaler, + "balanced_accuracy": balanced_acc, + "precision": precision, + "recall": recall, + "model_id": model_id, + "best_params": model.best_params_, + }, (model, imputer, scaler) + + +def train_model( + mode="analyzers", + labels=None, + analyzers=None, + metrics_paths=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + overwrite=False, + seed=None, + search_kwargs=None, + verbose=True, + enforce_metric_params=False, + **job_kwargs, +): + """ + Trains and evaluates machine learning models for spike sorting curation. + + This function initializes a `CurationModelTrainer` object, loads and preprocesses the data, + and evaluates the specified combinations of imputation strategies, scaling techniques, and classifiers. + The evaluation results, including the best model and its parameters, are saved to the output folder. + + Parameters + ---------- + mode : "analyzers" | "csv", default: "analyzers" + Mode to use for training. + analyzers : list of SortingAnalyzer | None, default: None + List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode. + labels : list of list | None, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + metrics_paths : list of str or None, default: None + List of paths to the CSV files containing the metrics data if using 'csv' mode. + folder : str | None, default: None + The folder where outputs such as models and evaluation metrics will be saved. 
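After `_save` has run, the output `folder` holds `best_model.skops`, `model_accuracies.csv`, `training_data.csv`, `labels.csv` and `model_info.json`. A small sketch of inspecting the saved metadata, assuming a hypothetical folder name:

    import json
    from pathlib import Path

    folder = Path("my_trained_model")  # hypothetical folder passed to train_model
    model_info = json.loads((folder / "model_info.json").read_text(encoding="utf8"))

    print(model_info["requirements"])          # package versions recorded at training time
    print(model_info["label_conversion"])      # integer codes mapped back to the original labels
    print(model_info["metric_params"].keys())  # parameters the metrics were computed with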
+ metric_names : list of str | None, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str | dict | None, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + overwrite : bool, default: False + Overwrites the `folder` if it already exists + seed : int | None, default: None + Random seed for reproducibility. If None, a random seed will be generated. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + verbose : bool, default: True + If True, useful information is printed during training. + enforce_metric_params : bool, default: False + If True and metric parameters used to calculate metrics for different `sorting_analyzer`s are + different, an error will be raised. + + + Returns + ------- + CurationModelTrainer + The `CurationModelTrainer` object used for training and evaluation. + + Notes + ----- + This function handles the entire workflow of initializing the trainer, loading and preprocessing the data, + and evaluating the models. The evaluation results are saved to the specified output folder. + """ + + if folder is None: + raise Exception("You must supply a folder for the model to be saved in using `folder='path/to/folder/'`") + + if overwrite is False: + assert not Path(folder).is_dir(), f"folder {folder} already exists, choose another name or use overwrite=True" + + if labels is None: + raise Exception("You must supply a list of lists of curated labels using `labels = [[...],[...],...]`") + + if mode not in ["analyzers", "csv"]: + raise Exception("`mode` must be equal to 'analyzers' or 'csv'.") + + if (test_size > 1.0) or (0.0 > test_size): + raise Exception("`test_size` must be between 0.0 and 1.0") + + trainer = CurationModelTrainer( + labels=labels, + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + test_size=test_size, + seed=seed, + verbose=verbose, + search_kwargs=search_kwargs, + **job_kwargs, + ) + + if mode == "analyzers": + assert analyzers is not None, "Analyzers must be provided as a list for mode 'analyzers'" + trainer.load_and_preprocess_analyzers(analyzers, enforce_metric_params) + + elif mode == "csv": + for metrics_path in metrics_paths: + assert Path(metrics_path).is_file(), f"{metrics_path} is not a file." 
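Putting the pieces together, a minimal end-to-end call in "csv" mode might look like the sketch below; the file names and labels are made up, and each list of labels needs one entry per row of the corresponding CSV:

    from spikeinterface.curation import train_model

    trainer = train_model(
        mode="csv",
        metrics_paths=["session1_metrics.csv", "session2_metrics.csv"],  # hypothetical metric tables
        labels=[[0, 1, 0, 1, 0], [1, 1, 0, 0, 1]],  # one curated label per unit, per file
        folder="trained_curation_model",
        imputation_strategies=["median"],
        scaling_techniques=["standard_scaler"],
        classifiers=["LogisticRegression"],
        search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 5},
        overwrite=True,
    )
    print(trainer.test_accuracies_df.head())  # summary of the evaluated model configurations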
+ trainer.load_and_preprocess_csv(metrics_paths) + + trainer.evaluate_model_config() + return trainer + + +def _get_computed_metrics(sorting_analyzer): + """Loads and organises the computed metrics from a sorting_analyzer into a single dataframe""" + + import pandas as pd + + quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer) + calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1) + + # Remove any metrics for non-existent units, raise error if no units are present + calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())] + if calculated_metrics.shape[0] == 0: + raise ValueError("No units present in sorting data") + + return calculated_metrics + + +def try_to_get_metrics_from_analyzer(sorting_analyzer): + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + if any(metric_extensions) is False: + raise ValueError( + "At least one of quality metrics or template metrics must be computed before classification.", + "Compute both using `sorting_analyzer.compute('quality_metrics', 'template_metrics')", + ) + + metric_extensions_data = [] + for metric_extension in metric_extensions: + try: + metric_extensions_data.append(metric_extension.get_data()) + except: + metric_extensions_data.append(None) + + return metric_extensions_data + + +def set_default_search_kwargs(search_kwargs): + + if search_kwargs is None: + search_kwargs = {} + + if search_kwargs.get("cv") is None: + search_kwargs["cv"] = 5 + if search_kwargs.get("scoring") is None: + search_kwargs["scoring"] = "balanced_accuracy" + if search_kwargs.get("n_iter") is None: + search_kwargs["n_iter"] = 25 + + return search_kwargs + + +def check_metric_names_are_the_same(metrics_for_each_analyzer): + """ + Given a list of dataframes, checks that the keys are all equal. + """ + + for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer): + for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer): + if i > j: + metric_names_1 = set(metrics_for_analyzer_1.keys()) + metric_names_2 = set(metrics_for_analyzer_2.keys()) + if metric_names_1 != metric_names_2: + metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2) + metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1) + + error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n" + if metrics_in_1_but_not_2: + error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does." + if metrics_in_2_but_not_1: + error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does." + raise Exception(error_message) + + +def _format_metric_dataframe(input_data): + + input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x) + input_data = input_data.astype("float32") + + return input_data diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 3a4be9213a..ab08401382 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -20,10 +20,10 @@ def export_report( **job_kwargs, ): """ - Exports a SI spike sorting report. The report includes summary figures of the spike sorting output - (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports, - that include waveforms, templates, template maps, ISI distributions, and more. 
- + Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. + What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`, + and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index d7e5b58e11..728d352973 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import warnings import numpy as np import probeinterface @@ -30,8 +31,10 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file : str or None, default None + cbin_file_path : str, Path or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. + cbin_file : str or None, default None + (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- @@ -41,14 +44,23 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): + def __init__( + self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None + ): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is None: + if cbin_file is not None: + warnings.warn( + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + DeprecationWarning, + stacklevel=2, + ) + cbin_file_path = cbin_file + if cbin_file_path is None: folder_path = Path(folder_path) # check bands assert stream_name in ["ap", "lp"], "stream_name must be one of: 'ap', 'lp'" @@ -60,17 +72,17 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", assert ( len(curr_cbin_files) == 1 ), f"There should only be one `*.cbin` file in the folder, but {print(curr_cbin_files)} have been found" - cbin_file = curr_cbin_files[0] + cbin_file_path = curr_cbin_files[0] else: - cbin_file = Path(cbin_file) - folder_path = cbin_file.parent + cbin_file_path = Path(cbin_file_path) + folder_path = cbin_file_path.parent - ch_file = cbin_file.with_suffix(".ch") - meta_file = cbin_file.with_suffix(".meta") + ch_file = cbin_file_path.with_suffix(".ch") + meta_file = cbin_file_path.with_suffix(".meta") # reader cbuffer = mtscomp.Reader() - cbuffer.open(cbin_file, ch_file) + cbuffer.open(cbin_file_path, ch_file) # meta data meta = read_meta_file(meta_file) @@ -119,7 +131,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", self._kwargs = { "folder_path": str(Path(folder_path).resolve()), "load_sync_channel": load_sync_channel, - "cbin_file": str(Path(cbin_file).resolve()), + "cbin_file_path": str(Path(cbin_file_path).resolve()), } diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 5dd549347d..317ea21cce 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -105,6 +105,8 
@@ def get_stream_names(eid: str, cache_folder: Optional[Union[Path, str]] = None, An instance of the ONE API to use for data loading. If not provided, a default instance is created using the default parameters. If you need to use a specific instance, you can create it using the ONE API and pass it here. + stream_type : "ap" | "lf" | None, default: None + The stream type to load, required when pid is provided and stream_name is not. Returns ------- @@ -140,6 +142,7 @@ def __init__( remove_cached: bool = True, stream: bool = True, one: "one.api.OneAlyx" = None, + stream_type: str | None = None, ): try: from brainbox.io.one import SpikeSortingLoader @@ -154,20 +157,24 @@ def __init__( one = IblRecordingExtractor._get_default_one(cache_folder=cache_folder) if pid is not None: + assert stream_type is not None, "When providing a PID, you must also provide a stream type." eid, _ = one.pid2eid(pid) - - stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) - if len(stream_names) > 1: - assert ( - stream_name is not None - ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." - assert stream_name in stream_names, ( - f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " - f"Please choose one of {stream_names}." - ) + pids, probes = one.eid2pid(eid) + pname = probes[pids.index(pid)] + stream_name = f"{pname}.{stream_type}" else: - stream_name = stream_names[0] - pname, stream_type = stream_name.split(".") + stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) + if len(stream_names) > 1: + assert ( + stream_name is not None + ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." + assert stream_name in stream_names, ( + f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " + f"Please choose one of {stream_names}." + ) + else: + stream_name = stream_names[0] + pname, stream_type = stream_name.split(".") self.ssl = SpikeSortingLoader(one=one, eid=eid, pid=pid, pname=pname) if pid is None: diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index f055e1d7c9..d2886d9e79 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -72,6 +72,7 @@ def write_recording( params_fname="params.json", geom_fname="geom.csv", dtype=None, + verbose=False, **job_kwargs, ): """Write a recording to file in MDA format. @@ -93,6 +94,8 @@ def write_recording( File name of geom file dtype : dtype or None, default: None Data type to be used. If None dtype is same as recording traces. + verbose : bool + If True, shows progress bar when saving recording. 
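In context, the new `verbose` flag for MDA export might be used as in the sketch below; the output folder is arbitrary, the recording is synthetic, and channel ids are renamed to integers in the same way as the updated MDA test further below:

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.extractors import MdaRecordingExtractor

    recording, _ = generate_ground_truth_recording(durations=[10.0], num_units=5)
    recording = recording.rename_channels(new_channel_ids=list(range(recording.get_num_channels())))

    # verbose=True shows a progress bar while the traces are written
    MdaRecordingExtractor.write_recording(recording, "mda_export_folder", verbose=True)
    recording_mda = MdaRecordingExtractor("mda_export_folder")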
**job_kwargs: Use by job_tools modules to set: @@ -130,6 +133,7 @@ def write_recording( dtype=dtype, byte_offset=header_size, add_file_extension=False, + verbose=verbose, **job_kwargs, ) diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index bf52de7c1d..03d517b46e 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -9,6 +9,7 @@ from .mearec import MEArecRecordingExtractor, MEArecSortingExtractor, read_mearec from .mcsraw import MCSRawRecordingExtractor, read_mcsraw from .neuralynx import NeuralynxRecordingExtractor, NeuralynxSortingExtractor, read_neuralynx, read_neuralynx_sorting +from .neuronexus import NeuroNexusRecordingExtractor, read_neuronexus from .neuroscope import ( NeuroScopeRecordingExtractor, NeuroScopeSortingExtractor, @@ -54,6 +55,7 @@ MCSRawRecordingExtractor, NeuralynxRecordingExtractor, NeuroScopeRecordingExtractor, + NeuroNexusRecordingExtractor, NixRecordingExtractor, OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, diff --git a/src/spikeinterface/extractors/neoextractors/neuronexus.py b/src/spikeinterface/extractors/neoextractors/neuronexus.py new file mode 100644 index 0000000000..dca482b28a --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/neuronexus.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path + +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor + + +class NeuroNexusRecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading data from NeuroNexus Allego. + + Based on :py:class:`neo.rawio.NeuronexusRawIO` + + Parameters + ---------- + file_path : str | Path + The file path to the metadata .xdat.json file of an Allego session + stream_id : str | None, default: None + If there are several streams, specify the stream id you want to load. + stream_name : str | None, default: None + If there are several streams, specify the stream name you want to load. + all_annotations : bool, default: False + Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the + names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO. 
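Reading an Allego session with the new extractor could look roughly like this; the file path is hypothetical, and `stream_id="0"` mirrors the test entity added further below:

    from spikeinterface.extractors import read_neuronexus

    # Point at the .xdat.json metadata file of an Allego session (hypothetical path)
    recording = read_neuronexus("allego_session/allego_2__uid0701-13-04-49.xdat.json", stream_id="0")
    print(recording.get_num_channels(), recording.get_sampling_frequency())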
+ + In Neuronexus the ids provided by NeoRawIO are the hardware channel ids stored as `ntv_chan_name` within + the metadata and the names are the `chan_names` + + + """ + + NeoRawIOClass = "NeuroNexusRawIO" + + def __init__( + self, + file_path: str | Path, + stream_id: str | None = None, + stream_name: str | None = None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, + ) + + self._kwargs.update(dict(file_path=str(Path(file_path).resolve()))) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + + neo_kwargs = {"filename": str(file_path)} + + return neo_kwargs + + +read_neuronexus = define_function_from_class(source_class=NeuroNexusRecordingExtractor, name="read_neuronexus") diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 24bc7591e4..dd24e6cae7 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -250,6 +250,7 @@ def __init__( except: warnings.warn(f"Could not load synchronized timestamps for {stream_name}") + self.annotate(experiment_name=f"experiment{exp_id}") self._stream_folders = stream_folders self._kwargs.update( diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 2f360ed864..e0604f7496 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,10 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + reading_attempts : int, default: 25 + Number of attempts to read the file before raising an error. + This opening process is somewhat unreliable and might fail occasionally. Adjust this higher + if you encounter problems in opening the file.
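For Plexon2, the new parameter is forwarded through the functional wrapper of this class; a sketch with a hypothetical file name, where `stream_name="WB-Wideband"` follows the updated test entity below:

    from spikeinterface.extractors import read_plexon2

    # Increase reading_attempts if a .pl2 file intermittently fails to open
    recording = read_plexon2("session.pl2", stream_name="WB-Wideband", reading_attempts=50)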
Examples -------- @@ -37,8 +41,16 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): NeoRawIOClass = "Plexon2RawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): - neo_kwargs = self.map_to_neo_kwargs(file_path) + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + use_names_as_ids=True, + all_annotations=False, + reading_attempts: int = 25, + ): + neo_kwargs = self.map_to_neo_kwargs(file_path, reading_attempts=reading_attempts) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, @@ -50,8 +62,18 @@ def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids self._kwargs.update({"file_path": str(file_path)}) @classmethod - def map_to_neo_kwargs(cls, file_path): + def map_to_neo_kwargs(cls, file_path, reading_attempts: int = 25): + neo_kwargs = {"filename": str(file_path)} + + from packaging.version import Version + import neo + + neo_version = Version(neo.__version__) + + if neo_version > Version("0.13.3"): + neo_kwargs["reading_attempts"] = reading_attempts + return neo_kwargs diff --git a/src/spikeinterface/extractors/neuropixels_utils.py b/src/spikeinterface/extractors/neuropixels_utils.py index 7e717dd2eb..f7841aeae2 100644 --- a/src/spikeinterface/extractors/neuropixels_utils.py +++ b/src/spikeinterface/extractors/neuropixels_utils.py @@ -1,39 +1,47 @@ from __future__ import annotations import numpy as np +from typing import Optional -def get_neuropixels_sample_shifts(num_channels=384, num_channels_per_adc=12, num_cycles=None): +def get_neuropixels_sample_shifts( + num_channels: int = 384, num_channels_per_adc: int = 12, num_cycles: Optional[int] = None +) -> np.ndarray: """ - Calculates the relative sampling phase of each channel that results - from Neuropixels ADC multiplexing. + Calculate the relative sampling phase (inter-sample shifts) for each channel + in Neuropixels probes due to ADC multiplexing. - This information is needed to perform the preprocessing.phase_shift operation. + Neuropixels probes sample channels sequentially through multiple ADCs, + introducing slight temporal delays between channels within each sampling cycle. + These inter-sample shifts are fractions of the sampling period and are crucial + to consider during preprocessing steps, such as phase correction, to ensure + accurate alignment of the recorded signals. - See https://github.com/int-brain-lab/ibllib/blob/master/ibllib/ephys/neuropixel.py - - - for the original implementation. + This function computes these relative phase shifts, returning an array where + each value represents the fractional delay (ranging from 0 to 1) for the + corresponding channel. Parameters ---------- num_channels : int, default: 384 - The total number of channels in a recording. - All currently available Neuropixels variants have 384 channels. + Total number of channels in the recording. + Neuropixels probes typically have 384 channels. num_channels_per_adc : int, default: 12 - The number of channels per ADC on the probe. - Neuropixels 1.0 probes have 12 ADCs. + Number of channels assigned to each ADC on the probe. + Neuropixels 1.0 probes have 12 ADCs, each handling 32 channels. Neuropixels 2.0 probes have 16 ADCs. - num_cycles: int or None, default: None - The number of cycles in the ADC on the probe. - Neuropixels 1.0 probes have 13 cycles for AP and 12 for LFP. + num_cycles : int or None, default: None + Number of cycles in the ADC sampling sequence. 
+ Neuropixels 1.0 probes have 13 cycles for AP (action potential) signals + and 12 for LFP (local field potential) signals. Neuropixels 2.0 probes have 16 cycles. - If None, the num_channels_per_adc is used. + If None, defaults to the value of `num_channels_per_adc`. Returns ------- - sample_shifts : ndarray - The relative phase (from 0-1) of each channel + sample_shifts : np.ndarray + Array of relative phase shifts for each channel, with values ranging from 0 to 1, + representing the fractional delay within the sampling period due to sequential ADC sampling. """ if num_cycles is None: num_cycles = num_channels_per_adc @@ -44,9 +52,8 @@ def get_neuropixels_sample_shifts(num_channels=384, num_channels_per_adc=12, num sample_shifts = np.zeros_like(adc_indices) - for a in adc_indices: - sample_shifts[adc_indices == a] = np.arange(num_channels_per_adc) / num_cycles - + for adc_index in adc_indices: + sample_shifts[adc_indices == adc_index] = np.arange(num_channels_per_adc) / num_cycles return sample_shifts diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d797e64910..171992f6b1 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -599,6 +599,8 @@ def __init__( else: gains, offsets, locations, groups = self._fetch_main_properties_backend() self.extra_requirements.append("h5py") + if stream_mode is not None: + self.extra_requirements.append(stream_mode) self.set_channel_gains(gains) self.set_channel_offsets(offsets) if locations is not None: @@ -1100,6 +1102,8 @@ def __init__( for property_name, property_values in properties.items(): values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] self.set_property(property_name, values) + if stream_mode is not None: + self.extra_requirements.append(stream_mode) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) diff --git a/src/spikeinterface/extractors/tests/common_tests.py b/src/spikeinterface/extractors/tests/common_tests.py index 5432efa9f3..61cfc2a153 100644 --- a/src/spikeinterface/extractors/tests/common_tests.py +++ b/src/spikeinterface/extractors/tests/common_tests.py @@ -52,8 +52,11 @@ def test_open(self): num_samples = rec.get_num_samples(segment_index=segment_index) full_traces = rec.get_traces(segment_index=segment_index) - assert full_traces.shape == (num_samples, num_chans) - assert full_traces.dtype == dtype + assert full_traces.shape == ( + num_samples, + num_chans, + ), f"{full_traces.shape} != {(num_samples, num_chans)}" + assert full_traces.dtype == dtype, f"{full_traces.dtype} != {dtype=}" traces_sample_first = rec.get_traces(segment_index=segment_index, start_frame=0, end_frame=1) assert traces_sample_first.shape == (1, num_chans) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 0ef6697c6c..78e6afb65e 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder): cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) + ids_as_integers = [id for id in range(rec.get_num_channels())] + rec = rec.rename_channels(new_channel_ids=ids_as_integers) + + ids_as_integers = [id for id in range(sort.get_num_units())] + sort = sort.rename_units(new_unit_ids=ids_as_integers) + 
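A quick sketch of calling `get_neuropixels_sample_shifts` as documented above, for a Neuropixels 1.0 AP stream:

    from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts

    # Neuropixels 1.0 AP band: 384 channels, 12 channels per ADC, 13 ADC cycles
    sample_shifts = get_neuropixels_sample_shifts(num_channels=384, num_channels_per_adc=12, num_cycles=13)
    print(sample_shifts.shape)                       # (384,)
    print(sample_shifts.min(), sample_shifts.max())  # fractional delays in [0, 1)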
MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") rec_mda = MdaRecordingExtractor(cache_folder / "mdatest") probe = rec_mda.get_probe() diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 3f73161218..3da92331a6 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -181,6 +181,14 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ] +class NeuroNexusRecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = NeuroNexusRecordingExtractor + downloads = ["neuronexus"] + entities = [ + ("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "0"}), + ] + + class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] @@ -360,7 +368,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] entities = [ - ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ("plexon/4chDemoPL2.pl2", {"stream_name": "WB-Wideband"}), ] diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py index b3c5b9c934..9724ec3d9f 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py @@ -4,7 +4,7 @@ import pytest import numpy as np -from spikeinterface import load_extractor +from spikeinterface import load from spikeinterface.core.testing import check_recordings_equal from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor @@ -219,7 +219,7 @@ def test_sorting_s3_nwb_zarr(tmp_path): assert not sorting.check_serializability("pickle") # test to/from dict - sorting_loaded = load_extractor(sorting.to_dict()) + sorting_loaded = load(sorting.to_dict()) # just take 3 random units to test rng = np.random.default_rng(seed=2205) diff --git a/src/spikeinterface/full.py b/src/spikeinterface/full.py index 0cd0fb0fb5..b9410bc021 100644 --- a/src/spikeinterface/full.py +++ b/src/spikeinterface/full.py @@ -25,3 +25,4 @@ from .widgets import * from .exporters import * from .generation import * +from .benchmark import * diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59c12a9923 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -3,6 +3,8 @@ import warnings import numpy as np +from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel try: import numba @@ -12,10 +14,6 @@ HAVE_NUMBA = False -from spikeinterface.core import compute_sparsity, SortingAnalyzer, Templates -from spikeinterface.core.template_tools import get_template_extremum_channel, _get_nbefore, get_dense_templates_array - - def compute_monopolar_triangulation( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, @@ -77,7 +75,11 @@ def compute_monopolar_triangulation( contact_locations = sorting_analyzer_or_templates.get_channel_locations() - sparsity = 
compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -106,7 +108,7 @@ def compute_monopolar_triangulation( # wf is (nsample, nchan) - chann is only nieghboor wf = templates[i, :, :][:, chan_inds] if feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif feature == "peak_voltage": @@ -157,9 +159,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -180,7 +186,7 @@ def compute_center_of_mass( wf = templates[i, :, :] if feature == "ptp": - wf_data = (wf[:, chan_inds]).ptp(axis=0) + wf_data = np.ptp(wf[:, chan_inds], axis=0) elif feature == "mean": wf_data = (wf[:, chan_inds]).mean(axis=0) elif feature == "energy": @@ -650,8 +656,55 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. 
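Once registered in `_unit_location_methods` (see the end of this hunk), the new method is available through the usual extension mechanism; a sketch using a synthetic recording:

    from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)
    analyzer = create_sorting_analyzer(sorting, recording)
    analyzer.compute(["random_spikes", "templates"])

    analyzer.compute("unit_locations", method="max_channel", peak_sign="neg")
    unit_locations = analyzer.get_extension("unit_locations").get_data()
    print(unit_locations.shape)  # (num_units, 2): x/y of each unit's extremum channel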
+ + This uses internally `get_template_extremum_channel()` + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: list[str] | list[int] | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_locations: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_locations + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 1871c11b85..38ea6b3824 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -1,12 +1,14 @@ from __future__ import annotations -import shutil -import pickle import warnings -import tempfile +import platform from pathlib import Path from tqdm.auto import tqdm +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits + import numpy as np from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -314,11 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] + mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_worker, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -411,12 +415,16 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context): from sklearn.decomposition import IncrementalPCA - from concurrent.futures import ProcessPoolExecutor p = self.params + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" 
+ elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids # there is one PCA per channel for independent fit per channel @@ -436,13 +444,18 @@ def _fit_by_channel_local(self, n_jobs, progress_bar): pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # parallel + # create list of args to parallelize. For convenience, the max_threads_per_worker is passed + # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_worker) + for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) - with ProcessPoolExecutor(max_workers=n_jobs) as executor: + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context), + ) as executor: results = executor.map(_partial_fit_one_channel, items) for chan_ind, pca_model_updated in results: pca_models[chan_ind] = pca_model_updated @@ -645,10 +658,6 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model): worker_ctx = {} - if isinstance(recording, dict): - from spikeinterface.core import load_extractor - - recording = load_extractor(recording) worker_ctx["recording"] = recording worker_ctx["sorting"] = sorting @@ -674,6 +683,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan = args - pca_model.partial_fit(wf_chan) - return chan_ind, pca_model + chan_ind, pca_model, wf_chan, max_threads_per_worker = args + + if max_threads_per_worker is None: + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model + else: + with threadpool_limits(limits=int(max_threads_per_worker)): + pca_model.partial_fit(wf_chan) + return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 306e9594b8..1969480503 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. - metrics_kwargs : dict - Additional arguments to pass to the metric functions. 
Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict of dicts or None, default: None + Dictionary with parameters for template metrics calculation. + Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` Returns ------- @@ -97,18 +84,32 @@ class ComputeTemplateMetrics(AnalyzerExtension): extension_name = "template_metrics" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, + metric_params=None, metrics_kwargs=None, include_multi_channel_metrics=False, delete_existing_metrics=False, @@ -134,33 +135,24 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - if metrics_kwargs is None: - metrics_kwargs_ = _default_function_kwargs.copy() - if len(other_kwargs) > 0: - for m in other_kwargs: - if m in metrics_kwargs_: - metrics_kwargs_[m] = other_kwargs[m] - else: - metrics_kwargs_ = _default_function_kwargs.copy() - metrics_kwargs_.update(metrics_kwargs) + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. 
Please use `metric_params` instead" + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(metrics_kwargs) + + metric_params_ = get_default_tm_params(metric_names) + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metrics_kwargs"] - # checks that existing metrics were calculated using the same params - if existing_params != metrics_kwargs_: - warnings.warn( - f"The parameters used to calculate the previous template metrics are different" - f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." - ) - tm_extension.params["metric_names"] = [] - existing_metric_names = [] - else: - existing_metric_names = tm_extension.params["metric_names"] - + existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] @@ -171,7 +163,7 @@ def _set_params( sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs_, + metric_params=metric_params_, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, ) @@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.params["delete_existing_metrics"] metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( @@ -343,9 +335,21 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): - computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + # some metrics names produce data columns with other names. This deals with that. 
+ for column_name in tm_compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = tm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics @@ -372,6 +376,35 @@ def _get_data(self): ) +def get_default_tm_params(metric_names): + if metric_names is None: + metric_names = get_template_metric_names() + + base_tm_params = _default_function_kwargs + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(base_tm_params) + + return metric_params + + +# a dict converting the name of the metric for computation to the output of that computation +tm_compute_name_to_column_names = { + "peak_to_valley": ["peak_to_valley"], + "peak_trough_ratio": ["peak_trough_ratio"], + "half_width": ["half_width"], + "repolarization_slope": ["repolarization_slope"], + "recovery_slope": ["recovery_slope"], + "num_positive_peaks": ["num_positive_peaks"], + "num_negative_peaks": ["num_negative_peaks"], + "velocity_above": ["velocity_above"], + "velocity_below": ["velocity_below"], + "exp_decay": ["exp_decay"], + "spread": ["spread"], +} + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0e70b1f494..6c30e2730b 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -7,6 +7,13 @@ from ..core.template_tools import get_dense_templates_array from ..core.sparsity import ChannelSparsity +try: + import numba + + HAVE_NUMBA = True +except ImportError: + HAVE_NUMBA = False + class ComputeTemplateSimilarity(AnalyzerExtension): """Compute similarity between templates with several methods. 
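Stepping back to the template-metrics changes above: the single `metrics_kwargs` dict is replaced by a per-metric `metric_params` dictionary, with unspecified metrics falling back to the defaults returned by `get_default_tm_params()`. A short usage sketch, mirroring the new test added later in this PR and assuming `sorting_analyzer` is an existing SortingAnalyzer with the "templates" extension computed:

```python
from spikeinterface.postprocessing import compute_template_metrics
from spikeinterface.postprocessing.template_metrics import get_default_tm_params

# Per-metric defaults, keyed by metric name
defaults = get_default_tm_params(["exp_decay", "spread", "half_width"])

# Override parameters per metric; "half_width" keeps the defaults
compute_template_metrics(
    sorting_analyzer=sorting_analyzer,
    metric_names=["exp_decay", "spread", "half_width"],
    metric_params={
        "exp_decay": {"recovery_window_ms": 0.8},
        "spread": {"spread_smooth_um": 15},
    },
)
metrics = sorting_analyzer.get_extension("template_metrics").get_data()
```

The old `metrics_kwargs` argument is still accepted for backward compatibility but is deprecated in favor of `metric_params` and is scheduled for removal in 0.104.0.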
@@ -37,7 +44,7 @@ class ComputeTemplateSimilarity(AnalyzerExtension): extension_name = "template_similarity" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True @@ -147,54 +154,15 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() -def compute_similarity_with_templates_array( - templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None -): - import sklearn.metrics.pairwise +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): - if method == "cosine_similarity": - method = "cosine" - - all_metrics = ["cosine", "l1", "l2"] - - if method not in all_metrics: - raise ValueError(f"compute_template_similarity (method {method}) not exists") - - assert ( - templates_array.shape[1] == other_templates_array.shape[1] - ), "The number of samples in the templates should be the same for both arrays" - assert ( - templates_array.shape[2] == other_templates_array.shape[2] - ), "The number of channels in the templates should be the same for both arrays" num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] - num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] - same_array = np.array_equal(templates_array, other_templates_array) - - mask = None - if sparsity is not None and other_sparsity is not None: - if support == "intersection": - mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - elif support == "union": - mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - units_overlaps = np.sum(mask, axis=2) > 0 - mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) - mask[~units_overlaps] = False - if mask is not None: - units_overlaps = np.sum(mask, axis=2) > 0 - overlapping_templates = {} - for i in range(num_templates): - overlapping_templates[i] = np.flatnonzero(units_overlaps[i]) - else: - # here we make a dense mask and overlapping templates - overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)} - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) - - assert num_shifts < num_samples, "max_lag is too large" num_shifts_both_sides = 2 * num_shifts + 1 distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t # So the matrix can be computed only for negative lags and be transposed @@ -210,8 +178,9 @@ def compute_similarity_with_templates_array( tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in range(num_templates): src_template = src_sliced_templates[i] - tgt_templates = tgt_sliced_templates[overlapping_templates[i]] - for gcount, j in enumerate(overlapping_templates[i]): + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount, j in enumerate(overlapping_templates): # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template @@ -222,23 +191,156 @@ def compute_similarity_with_templates_array( if method == "l1": norm_i = 
np.sum(np.abs(src)) norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() + distances[count, i, j] = np.sum(np.abs(src - tgt)) distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() + distances[count, i, j] = np.linalg.norm(src - tgt, ord=2) distances[count, i, j] /= norm_i + norm_j - else: - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( - src, tgt, metric="cosine" - ).item() + elif method == "cosine": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = np.sum(src * tgt) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] if same_array: distances[count, j, i] = distances[count, i, j] if same_array and num_shifts != 0: distances[num_shifts_both_sides - count - 1] = distances[count].T + return distances + + +if HAVE_NUMBA: + + from math import sqrt + + @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + other_num_templates = other_templates_array.shape[0] + + num_shifts_both_sides = 2 * num_shifts + 1 + distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = list(range(-num_shifts, 1)) + else: + shift_loop = list(range(-num_shifts, num_shifts + 1)) + + if method == "l1": + metric = 0 + elif method == "l2": + metric = 1 + elif method == "cosine": + metric = 2 + + for count in range(len(shift_loop)): + shift = shift_loop[count] + src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in numba.prange(num_templates): + src_template = src_sliced_templates[i] + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount in range(len(overlapping_templates)): + + j = overlapping_templates[gcount] + # symmetric values are handled later + if same_array and j < i: + # no need exhaustive looping when same template + continue + src = src_template[:, mask[i, j]].flatten() + tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + + norm_i = 0 + norm_j = 0 + distances[count, i, j] = 0 + + for k in range(len(src)): + if metric == 0: + norm_i += abs(src[k]) + norm_j += abs(tgt[k]) + distances[count, i, j] += abs(src[k] - tgt[k]) + elif metric == 1: + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 + distances[count, i, j] += (src[k] - tgt[k]) ** 2 + elif metric == 2: + distances[count, i, j] += src[k] * tgt[k] + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 + + if metric == 0: + distances[count, i, j] /= norm_i + norm_j + elif metric == 1: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) + distances[count, i, j] = sqrt(distances[count, i, j]) + distances[count, 
i, j] /= norm_i + norm_j + elif metric == 2: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] + + if same_array: + distances[count, j, i] = distances[count, i, j] + + if same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T + + return distances + + _compute_similarity_matrix = _compute_similarity_matrix_numba +else: + _compute_similarity_matrix = _compute_similarity_matrix_numpy + + +def compute_similarity_with_templates_array( + templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None +): + + if method == "cosine_similarity": + method = "cosine" + + all_metrics = ["cosine", "l1", "l2"] + + if method not in all_metrics: + raise ValueError(f"compute_template_similarity (method {method}) not exists") + + assert ( + templates_array.shape[1] == other_templates_array.shape[1] + ), "The number of samples in the templates should be the same for both arrays" + assert ( + templates_array.shape[2] == other_templates_array.shape[2] + ), "The number of channels in the templates should be the same for both arrays" + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + num_channels = templates_array.shape[2] + other_num_templates = other_templates_array.shape[0] + + mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + + if sparsity is not None and other_sparsity is not None: + if support == "intersection": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + elif support == "union": + mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + units_overlaps = np.sum(mask, axis=2) > 0 + mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) + mask[~units_overlaps] = False + + assert num_shifts < num_samples, "max_lag is too large" + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) distances = np.min(distances, axis=0) similarity = 1 - distances diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 66d84c9565..0431c8d675 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms): ) assert np.array_equal(result_numpy, result_numba) - assert np.array_equal(result_numpy, result_numba) @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index bf0000135c..be0070d94a 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -23,6 +23,11 @@ def get_dataset(): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + # since templates are going to be averaged and this might be a problem for amplitude scaling # we select the 3 units 
with the largest templates to split analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 4de86be32b..ecfc39f2c6 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -18,6 +18,18 @@ class TestPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): def test_extension(self, params): self.run_extension_tests(ComputePrincipalComponents, params=params) + def test_multi_processing(self): + """ + Test the extension works with multiple processes. + """ + sorting_analyzer = self._prepare_sorting_analyzer( + format="memory", sparse=False, extension_class=ComputePrincipalComponents + ) + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) + sorting_analyzer.compute( + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_worker=4, mp_context="spawn" + ) + def test_mode_concatenated(self): """ Replicate the "extension_function_params_list" test outside of diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5056d4ff2a..1bf49f64c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,5 +1,5 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics +from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest import csv @@ -8,6 +8,49 @@ template_metrics = list(_single_channel_metric_name_to_func.keys()) +def test_different_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using different params, and check that they are + actually calculated using the different params. 
+ """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread", "half_width"], + metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.7 + assert tm_params["half_width"]["recovery_window_ms"] == 0.7 + + assert tm_params["spread"]["spread_smooth_um"] == 15 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + assert tm_params["half_width"]["spread_smooth_um"] == 20 + + +def test_backwards_compat_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using the metrics_kwargs keyword + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread"], + metrics_kwargs={"recovery_window_ms": 0.8}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.8 + + assert tm_params["spread"]["spread_smooth_um"] == 20 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + + def test_compute_new_template_metrics(small_sorting_analyzer): """ Computes template metrics then computes a subset of template metrics, and checks @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. """ + small_sorting_analyzer.delete_extension("template_metrics") + # calculate just exp_decay small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -47,7 +92,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} ) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index cc6797c262..20d8373981 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -7,7 +7,23 @@ ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array +from spikeinterface.postprocessing.template_similarity import ( + compute_similarity_with_templates_array, + _compute_similarity_matrix_numpy, +) + +try: + import numba + + HAVE_NUMBA = True + from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba +except ModuleNotFoundError as err: + HAVE_NUMBA = False + +import pytest +from pytest import param + +SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): @@ -72,6 +88,35 @@ def test_compute_similarity_with_templates_array(params): print(similarity.shape) +pytest.mark.skipif(not HAVE_NUMBA, reason="Numba 
not available") + + +@pytest.mark.parametrize( + "params", + [ + dict(method="cosine", num_shifts=8), + dict(method="l1", num_shifts=0), + dict(method="l2", num_shifts=0), + dict(method="cosine", num_shifts=0), + ], +) +def test_equal_results_numba(params): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. + """ + + rng = np.random.default_rng(seed=2205) + templates_array = rng.random(size=(4, 20, 5), dtype=np.float32) + other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32) + mask = np.ones((4, 2, 5), dtype=bool) + + result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params) + result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params) + + assert np.allclose(result_numpy, result_numba, 1e-3) + + if __name__ == "__main__": from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset from spikeinterface.core import estimate_sparsity diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 4029fc88c7..df19458316 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. 
@@ -39,7 +39,7 @@ class ComputeUnitLocations(AnalyzerExtension): extension_name = "unit_locations" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index ee2083d3c4..88c5f7301a 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -132,6 +132,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): # now, divide by the number of channels at that position traces /= self.n_chans_each_pos + if channel_indices is not None: + traces = traces[:, channel_indices] + return traces diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 334ebb02d2..d5fc9d2025 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -63,18 +63,15 @@ def __init__( f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." ) self._decimation_offset = decimation_offset - resample_rate = self._orig_samp_freq / self._decimation_factor + decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - # in case there was a time_vector, it will be dropped for sanity. - # This is not necessary but consistent with ResampleRecording for parent_segment in recording._recording_segments: - parent_segment.time_vector = None self.add_recording_segment( DecimateRecordingSegment( parent_segment, - resample_rate, + decimated_sampling_frequency, self._orig_samp_freq, decimation_factor, decimation_offset, @@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment): def __init__( self, parent_recording_segment, - resample_rate, + decimated_sampling_frequency, parent_rate, decimation_factor, decimation_offset, dtype, ): - if parent_recording_segment.t_start is None: - new_t_start = None + if parent_recording_segment.time_vector is not None: + time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] + decimated_sampling_frequency = None + t_start = None else: - new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + time_vector = None + if parent_recording_segment.t_start is None: + t_start = None + else: + t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! 
BaseRecordingSegment.__init__( - self, - sampling_frequency=resample_rate, - t_start=new_t_start, + self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector ) self._parent_segment = parent_recording_segment self._decimation_factor = decimation_factor diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..00d9a1a407 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), @@ -95,7 +97,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_generator=noise_generator) + self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) + self._kwargs.update(random_chunk_kwargs) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): diff --git a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py index dc3edc3b1d..c0965d8e51 100644 --- a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py +++ b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py @@ -37,6 +37,13 @@ def test_average_across_direction(): assert np.all(geom_avgy[:2, 0] == 0) assert np.all(geom_avgy[2, 0] == 1.5) + # test with channel ids + # use chans at y in (1, 2) + traces = rec_avgy.get_traces(channel_ids=["0-1", "2-3"]) + assert traces.shape == (100, 2) + assert np.all(traces[:, 0] == 0.5) + assert np.all(traces[:, 1] == 2.5) + # test averaging across x rec_avgx = average_across_direction(rec, direction="x") traces = rec_avgx.get_traces() diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index 724ba2c963..c18c7d37af 100644 --- a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -14,12 +14,12 @@ def test_clip(): rec1 = clip(rec, a_min=-1.5) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(-2 <= traces0[0] <= 3) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"]) assert traces1.shape[1] == 2 assert np.all(-1.5 <= traces1[1]) @@ -34,11 +34,11 @@ def test_blank_staturation(): rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(traces0 < 
3.0) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure a_min = rec1._recording_segments[0].a_min diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 100972f762..aab17560a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -8,19 +8,14 @@ import numpy as np -@pytest.mark.parametrize("N_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(N_segments)] - +@pytest.mark.parametrize("num_segments", [1, 2]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) - parent_traces = [rec.get_traces(i) for i in range(N_segments)] + parent_traces = [rec.get_traces(i) for i in range(num_segments)] if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: with pytest.raises(ValueError): @@ -28,19 +23,59 @@ def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, return decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - for i in range(N_segments): + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) + + for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] assert np.all( decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): 
+ assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate() diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 9df60af3db..bf723c84b9 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -46,7 +46,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2) # Then, change all kwargs to ensure they are propagated # and check the backwards version. @@ -66,7 +66,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2) def test_causal_filter_custom_coeff(self, recording_and_data): """ @@ -89,7 +89,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True) # Next, in "sos" mode options["filter_mode"] = "sos" @@ -100,7 +100,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data): filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces() - assert np.allclose(test_data, filt_data, rtol=0, atol=1e-6, equal_nan=True) + assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True) def test_causal_kwarg_error_raised(self, recording_and_data): """ diff --git a/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py b/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py index 54682f2e94..3682b186f2 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py +++ b/src/spikeinterface/preprocessing/tests/test_filter_gaussian.py @@ -1,7 +1,7 @@ import numpy as np import pytest from pathlib import Path -from spikeinterface.core import load_extractor, set_global_tmp_folder +from spikeinterface.core import load, set_global_tmp_folder from spikeinterface.core.testing import check_recordings_equal from spikeinterface.core.generate import generate_recording from spikeinterface.preprocessing import gaussian_filter @@ -23,7 +23,7 @@ def test_filter_gaussian(tmp_path): assert rec_filtered.get_traces(segment_index=1, start_frame=rec_filtered.get_num_frames(1) - 200).shape == (200, 3) # Check dumpability - saved_loaded = load_extractor(rec_filtered.to_dict()) + saved_loaded = load(rec_filtered.to_dict()) check_recordings_equal(rec_filtered, saved_loaded, return_scaled=False) 
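The `test_decimate_with_times` test added above pins down the new time handling in `decimate.py`: instead of dropping a parent `time_vector`, `DecimateRecording` now subsamples it (and otherwise shifts `t_start` by the decimation offset). A condensed sketch of that behaviour, using imports that are assumptions based on where these objects are defined in this PR:

```python
import numpy as np
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing.decimate import DecimateRecording

rec = generate_recording(durations=[5])
rec.set_times(rec.get_times(0) + 10.0, segment_index=0)

decimated = DecimateRecording(rec, 2, decimation_offset=1)

# The decimated segment keeps every 2nd time stamp, starting at the offset
assert np.allclose(decimated.get_times(0), rec.get_times(0)[1::2])
```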
saved_1job = rec_filtered.save(folder=tmp_path / "1job") diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 1189f04f7d..06bde4e3d1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -163,7 +163,9 @@ def test_output_values(): expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) - si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1) + si_interpolated_recording = spre.interpolate_bad_channels( + recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1 + ) si_interpolated = si_interpolated_recording.get_traces() expected_ts = si_interpolated[:, 1:] @ expected_weights diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 576b570832..151752e0e6 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -15,7 +15,7 @@ def test_normalize_by_quantile(): rec2 = normalize_by_quantile(rec, mode="by_channel") rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 rec2 = normalize_by_quantile(rec, mode="pool_channel") diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index b8bb31015e..a2a06e7a1f 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -15,7 +15,7 @@ def test_rectify(): rec2 = rectify(rec) rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 # import matplotlib.pyplot as plt diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() diff --git 
a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 195969ff79..57400c1199 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be whitened. dtype : None or dtype, default: None + Datatype of the output recording (covariance matrix estimation + and whitening are performed in float32). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. 
mode : "global" | "local", default: "global" @@ -74,7 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if W is not None: W = np.asarray(W) @@ -124,7 +128,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) traces_dtype = traces.dtype - # if uint --> force int + # if uint --> force float if traces_dtype.kind == "u": traces = traces.astype("float32") @@ -185,6 +189,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py index 9d604f6ae2..754c82d8e3 100644 --- a/src/spikeinterface/qualitymetrics/__init__.py +++ b/src/spikeinterface/qualitymetrics/__init__.py @@ -6,4 +6,4 @@ get_default_qm_params, ) from .pca_metrics import get_quality_pca_metric_list -from .misc_metrics import get_synchrony_counts +from .misc_metrics import _get_synchrony_counts diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dfd41cf88..6007de379c 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -520,7 +520,7 @@ def compute_sliding_rp_violations( ) -def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): +def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): """ Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. @@ -528,10 +528,10 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): ---------- spikes : np.array Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). - synchrony_sizes : numpy array - The synchrony sizes to compute. Should be pre-sorted. all_unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. Expecting all units. + synchrony_sizes : None or np.array, default: None + The synchrony sizes to compute. Should be pre-sorted. Returns ------- @@ -565,37 +565,38 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of - "synchrony_size" spikes at the exact same sample index. + spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. Parameters ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. - synchrony_sizes : list or tuple, default: (2, 4, 8) - The synchrony sizes to compute. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + synchrony_sizes: None, default: None + Deprecated argument. 
Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. Returns ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - Returns are as many as synchrony_sizes. References ---------- Based on concepts described in [Grün]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" - # Sort the synchrony times so we can slice numpy arrays, instead of using dicts - synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16) - synchrony_sizes_np.sort() - res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np]) + if synchrony_sizes is not None: + warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) + + synchrony_sizes = np.array([2, 4, 8]) + + res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting @@ -606,10 +607,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids - synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids) + synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids) synchrony_metrics_dict = {} - for sync_idx, synchrony_size in enumerate(synchrony_sizes_np): + for sync_idx, synchrony_size in enumerate(synchrony_sizes): sync_id_metrics_dict = {} for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: @@ -623,7 +624,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ return res(**synchrony_metrics_dict) -_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) +_default_params["synchrony"] = dict() def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 7c099a2f74..11f590869e 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -2,15 +2,17 @@ from __future__ import annotations - +import warnings from copy import deepcopy - -import numpy as np +import platform from tqdm.auto import tqdm -from concurrent.futures import ProcessPoolExecutor +from warnings import warn +import numpy as np -import warnings +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits from .misc_metrics import compute_num_spikes, compute_firing_rates @@ -40,6 +42,9 @@ max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" ), silhouette=dict(method=("simplified",)), + isolation_distance=dict(), + l_ratio=dict(), + d_prime=dict(), ) @@ -51,11 +56,14 @@ def get_quality_pca_metric_list(): def compute_pc_metrics( sorting_analyzer, metric_names=None, + metric_params=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False, + mp_context=None, + max_threads_per_worker=None, ) -> dict: """ Calculate principal component derived metrics. @@ -67,7 +75,7 @@ def compute_pc_metrics( metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. 
- qm_params : dict or None + metric_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None List of unit ids to compute metrics for. @@ -83,6 +91,14 @@ def compute_pc_metrics( pc_metrics : dict The computed PC metrics. """ + + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + ) + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + metric_params = qm_params + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" @@ -90,8 +106,8 @@ def compute_pc_metrics( if metric_names is None: metric_names = _possible_pc_metric_names.copy() - if qm_params is None: - qm_params = _default_params + if metric_params is None: + metric_params = _default_params extremum_channels = get_template_extremum_channel(sorting_analyzer) @@ -144,17 +160,8 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = ( - pcs_flat, - labels, - non_nn_metrics, - unit_id, - unit_ids, - qm_params, - seed, - n_spikes_all_units, - fr_all_units, - ) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_worker) + items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -167,7 +174,15 @@ def compute_pc_metrics( for metric_name, metric in pca_metrics_unit.items(): pc_metrics[metric_name][unit_id] = metric elif run_in_parallel and non_nn_metrics: - with ProcessPoolExecutor(n_jobs) as executor: + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + + with ProcessPoolExecutor( + max_workers=n_jobs, + mp_context=mp.get_context(mp_context), + ) as executor: results = executor.map(pca_metrics_one_unit, items) if progress_bar: results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics") @@ -183,7 +198,7 @@ def compute_pc_metrics( units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) func = _nn_metric_name_to_func[metric_name] - metric_params = qm_params[metric_name] if metric_name in qm_params else {} + metric_params = metric_params[metric_name] if metric_name in metric_params else {} for _, unit_id in units_loop: try: @@ -212,7 +227,7 @@ def compute_pc_metrics( def calculate_pc_metrics( - sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): warnings.warn( "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. 
Please use compute_pc_metrics instead", @@ -223,7 +238,7 @@ def calculate_pc_metrics( pc_metrics = compute_pc_metrics( sorting_analyzer, metric_names=metric_names, - qm_params=qm_params, + metric_params=metric_params, unit_ids=unit_ids, seed=seed, n_jobs=n_jobs, @@ -976,26 +991,20 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - ( - pcs_flat, - labels, - metric_names, - unit_id, - unit_ids, - qm_params, - seed, - # we_folder, - n_spikes_all_units, - fr_all_units, - ) = args - - # if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names: - # we = load_waveforms(we_folder) + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_worker) = args + + if max_threads_per_worker is None: + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) + else: + with threadpool_limits(limits=int(max_threads_per_worker)): + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) + + +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: - try: isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id) except: @@ -1021,7 +1030,7 @@ def pca_metrics_one_unit(args): if "nearest_neighbor" in metric_names: try: nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] ) except: nn_hit_rate = np.nan @@ -1030,7 +1039,7 @@ def pca_metrics_one_unit(args): pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: - silhouette_method = qm_params["silhouette"]["method"] + silhouette_method = metric_params["silhouette"]["method"] if "simplified" in silhouette_method: try: unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 3b6c6d3e50..11ce3d0160 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,6 +6,7 @@ from copy import deepcopy import numpy as np +from warnings import warn from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -15,7 +16,8 @@ compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names, - compute_name_to_column_names, + qm_compute_name_to_column_names, + column_name_to_column_dtype, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -31,7 +33,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - qm_params : dict or None + metric_params : dict of dicts or None Dictionary with parameters for quality metrics calculation. 
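A hedged call sketch for the reworked PC-metric entry point above (keyword names taken from this diff; `sorting_analyzer` is assumed to already have its principal components and supporting extensions computed):

from spikeinterface.qualitymetrics import compute_pc_metrics

pc_metrics = compute_pc_metrics(
    sorting_analyzer,
    metric_names=["isolation_distance", "l_ratio", "d_prime"],
    metric_params=None,        # `qm_params=` still works but now emits a DeprecationWarning
    n_jobs=4,
    mp_context="spawn",        # "fork" is rejected on Windows and warned against on macOS
    max_threads_per_worker=1,  # caps BLAS/OpenMP threads per worker via threadpoolctl
    progress_bar=True,
)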
Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -54,10 +56,18 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, metric_names=None, + metric_params=None, qm_params=None, peak_sign=None, seed=None, @@ -65,6 +75,12 @@ def _set_params( delete_existing_metrics=False, metrics_to_compute=None, ): + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -80,12 +96,12 @@ def _set_params( if "drift" in metric_names: metric_names.remove("drift") - qm_params_ = get_default_qm_params() - for k in qm_params_: - if qm_params is not None and k in qm_params: - qm_params_[k].update(qm_params[k]) - if "peak_sign" in qm_params_[k] and peak_sign is not None: - qm_params_[k]["peak_sign"] = peak_sign + metric_params_ = get_default_qm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") @@ -101,7 +117,7 @@ def _set_params( metric_names=metric_names, peak_sign=peak_sign, seed=seed, - qm_params=qm_params_, + metric_params=metric_params_, skip_pc_metrics=skip_pc_metrics, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, @@ -125,13 +141,20 @@ def _merge_extension_data( all_unit_ids = new_sorting_analyzer.unit_ids not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] metrics.loc[new_unit_ids, :] = self._compute_metrics( new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs ) + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. 
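A standalone pandas illustration of the dtype issue described in the comment above (made-up values, not part of the diff):

import numpy as np
import pandas as pd

old_metrics = pd.DataFrame({"snr": np.array([5.1, 3.2])}, index=["u0", "u1"])
# building a frame from index/columns alone yields object-dtype columns
merged = pd.DataFrame(index=["u0", "u1", "u2"], columns=old_metrics.columns)
merged.loc[["u0", "u1"], :] = old_metrics.loc[["u0", "u1"], :]
# cast back to the original dtype; the unfilled unit keeps NaN as a float
merged["snr"] = merged["snr"].astype(old_metrics["snr"].dtype)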
+ for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + new_data = dict(metrics=metrics) return new_data @@ -141,7 +164,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - qm_params = self.params["qm_params"] + metric_params = self.params["metric_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -177,7 +200,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri func = _misc_metric_name_to_func[metric_name] - params = qm_params[metric_name] if metric_name in qm_params else {} + params = metric_params[metric_name] if metric_name in metric_params else {} res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: @@ -205,7 +228,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, - qm_params=qm_params, + metric_params=metric_params, seed=seed, ) for col, values in pc_metrics.items(): @@ -214,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # add NaN for empty units if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan + # num_spikes is an int and should be 0 + if "num_spikes" in metrics.columns: + metrics.loc[empty_unit_ids, ["num_spikes"]] = 0 # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns # (in case of NaN values) metrics = metrics.convert_dtypes() + + # we do this because the convert_dtypes infers the wrong types sometimes. + # the actual types for columns can be found in column_name_to_column_dtype dictionary. + for column in metrics.columns: + if column in column_name_to_column_dtype: + metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) + return metrics def _run(self, verbose=False, **job_kwargs): @@ -234,7 +267,8 @@ def _run(self, verbose=False, **job_kwargs): ) existing_metrics = [] - qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + qm_extension = self.sorting_analyzer.extensions.get("quality_metrics", None) if ( delete_existing_metrics is False and qm_extension is not None @@ -245,7 +279,7 @@ def _run(self, verbose=False, **job_kwargs): # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): # some metrics names produce data columns with other names. This deals with that. 
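# e.g. "synchrony" maps to ["sync_spike_2", "sync_spike_4", "sync_spike_8"] and
# "amplitude_cv" maps to ["amplitude_cv_median", "amplitude_cv_range"] in qm_compute_name_to_column_names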
- for column_name in compute_name_to_column_names[metric_name]: + for column_name in qm_compute_name_to_column_names[metric_name]: computed_metrics[column_name] = qm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 375dd320ae..23b781eb9d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -55,7 +55,7 @@ } # a dict converting the name of the metric for computation to the output of that computation -compute_name_to_column_names = { +qm_compute_name_to_column_names = { "num_spikes": ["num_spikes"], "firing_rate": ["firing_rate"], "presence_ratio": ["presence_ratio"], @@ -66,7 +66,11 @@ "amplitude_cutoff": ["amplitude_cutoff"], "amplitude_median": ["amplitude_median"], "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], - "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "synchrony": [ + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + ], "firing_range": ["firing_range"], "drift": ["drift_ptp", "drift_std", "drift_mad"], "sd_ratio": ["sd_ratio"], @@ -79,3 +83,38 @@ "silhouette": ["silhouette"], "silhouette_full": ["silhouette_full"], } + +# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them +column_name_to_column_dtype = { + "num_spikes": int, + "firing_rate": float, + "presence_ratio": float, + "snr": float, + "isi_violations_ratio": float, + "isi_violations_count": float, + "rp_violations": float, + "rp_contamination": float, + "sliding_rp_violation": float, + "amplitude_cutoff": float, + "amplitude_median": float, + "amplitude_cv_median": float, + "amplitude_cv_range": float, + "sync_spike_2": float, + "sync_spike_4": float, + "sync_spike_8": float, + "firing_range": float, + "drift_ptp": float, + "drift_std": float, + "drift_mad": float, + "sd_ratio": float, + "isolation_distance": float, + "l_ratio": float, + "d_prime": float, + "nn_hit_rate": float, + "nn_miss_rate": float, + "nn_isolation": float, + "nn_unit_id": float, + "nn_noise_overlap": float, + "silhouette": float, + "silhouette_full": float, +} diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..39bc62ae12 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -8,14 +8,18 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -@pytest.fixture(scope="module") -def small_sorting_analyzer(): +def make_small_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -35,6 +39,11 @@ def small_sorting_analyzer(): return sorting_analyzer +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return make_small_analyzer() + + @pytest.fixture(scope="module") def sorting_analyzer_simple(): # we need high firing rate for 
amplitude_cutoff @@ -60,6 +69,11 @@ def sorting_analyzer_simple(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..18b49cd862 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -39,7 +39,7 @@ compute_firing_ranges, compute_amplitude_cv_metrics, compute_sd_ratio, - get_synchrony_counts, + _get_synchrony_counts, compute_quality_metrics, ) @@ -69,7 +69,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert calculated_metrics == ["snr"] small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} ) small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) @@ -96,13 +96,13 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): # check that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") new_snr_data = new_quality_metric_extension.get_data()["snr"].values assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" # check that all quality metrics are deleted when parents are recomputed, even after # recomputation @@ -280,10 +280,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params ) for metric, metric_2_data in quality_metrics_2.items(): @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync(): one_spike["sample_index"] = spike_times one_spike["unit_index"] = spike_units - sync_count = get_synchrony_counts(one_spike, np.array((2)), [0]) + sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0]) assert np.all(sync_count[0] == np.array([0])) @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync(): two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) two_spikes["unit_index"] = np.concatenate((spike_labels, 
added_spikes_labels)) - sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1]) + sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1]) assert np.all(sync_count[0] == np.array([1, 1])) @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync(): four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3]) + sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3]) assert np.all(sync_count[0] == np.array([1, 1, 1, 1])) assert np.all(sync_count[1] == np.array([1, 1, 1, 1])) @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units(): three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2]) + sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2]) assert np.all(sync_count[0] == np.array([0, 1, 1])) @@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations): def test_synchrony_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple sorting = sorting_analyzer.sorting - synchrony_sizes = (2, 3, 4) - synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes) - print(synchrony_metrics) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer) + + synchrony_sizes = np.array([2, 4, 8]) # check returns for size in synchrony_sizes: @@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple): sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory") - previous_synchrony_metrics = compute_synchrony_metrics( - previous_sorting_analyzer, synchrony_sizes=synchrony_sizes - ) - current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes) + previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer) + current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync) print(current_synchrony_metrics) # check that all values increased for i, col in enumerate(previous_synchrony_metrics._fields): @@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple): unit_ids_subset = [3, 7] - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics( - sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset - ) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset) - assert list(synchrony_metrics.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7] def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple): - # all_unit_ids = sorting_analyzer_simple.sorting.unit_ids - - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes) - - assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple) + assert 
np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids) @pytest.mark.sortingcomponents diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..287439a4f7 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -1,9 +1,7 @@ import pytest import numpy as np -from spikeinterface.qualitymetrics import ( - compute_pc_metrics, -) +from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list def test_calculate_pc_metrics(small_sorting_analyzer): @@ -17,8 +15,50 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res2 = pd.DataFrame(res2) for metric_name in res1.columns: + values1 = res1[metric_name].values + values2 = res1[metric_name].values + if metric_name != "nn_unit_id": - assert not np.all(np.isnan(res1[metric_name].values)) - assert not np.all(np.isnan(res2[metric_name].values)) + assert not np.all(np.isnan(values1)) + assert not np.all(np.isnan(values2)) + + if values1.dtype.kind == "f": + np.testing.assert_almost_equal(values1, values2, decimal=4) + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, share=True) + # ax =a xs[0] + # ax.plot(res1[metric_name].values) + # ax.plot(res2[metric_name].values) + # ax =a xs[1] + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + else: + assert np.array_equal(values1, values2) + + +def test_pca_metrics_multi_processing(small_sorting_analyzer): + sorting_analyzer = small_sorting_analyzer + + metric_names = get_quality_pca_metric_list() + metric_names.remove("nn_isolation") + metric_names.remove("nn_noise_overlap") + + print(f"Computing PCA metrics with 1 thread per process") + res1 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True + ) + print(f"Computing PCA metrics with 2 thread per process") + res2 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True + ) + print("Computing PCA metrics with spawn context") + res2 = compute_pc_metrics( + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True + ) + + +if __name__ == "__main__": + from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer - assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + small_sorting_analyzer = make_small_analyzer() + test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index a6415c58e8..ea8939ebb4 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -24,14 +24,14 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) # print(metrics) qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in 
metrics.columns assert "isolation_distance" not in metrics.columns @@ -40,7 +40,7 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -48,9 +48,10 @@ def test_compute_quality_metrics(sorting_analyzer_simple): assert "isolation_distance" in metrics.columns -def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): +def test_merging_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple + metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, @@ -59,6 +60,32 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): seed=2205, ) + # sorting_analyzer_simple has ten units + new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]]) + + new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data() + + # we should copy over the metrics after merge + for column in metrics.columns: + assert column in new_metrics.columns + # should copy dtype too + assert metrics[column].dtype == new_metrics[column].dtype + + # 10 units vs 9 units + assert len(metrics.index) > len(new_metrics.index) + + +def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): + + sorting_analyzer = sorting_analyzer_simple + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + # make a copy and make it recordingless sorting_analyzer_norec = sorting_analyzer.save_as(format="memory") sorting_analyzer_norec.delete_extension("quality_metrics") @@ -68,7 +95,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics_norec = compute_quality_metrics( sorting_analyzer_norec, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -101,15 +128,20 @@ def test_empty_units(sorting_analyzer_simple): metrics_empty = compute_quality_metrics( sorting_analyzer_empty, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) - for empty_unit_id in sorting_empty.get_empty_unit_ids(): + # num_spikes are ints not nans so we confirm empty units are nans for everything except + # num_spikes which should be 0 + nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"] + for empty_unit_ids in sorting_empty.get_empty_unit_ids(): from pandas import isnull - assert np.all(isnull(metrics_empty.loc[empty_unit_id].values)) + assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values)) + if "num_spikes" in metrics_empty.columns: + assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 3502d27548..4492057f21 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -15,7 +15,7 @@ import warnings -from spikeinterface.core import load_extractor, BaseRecordingSnippets, BaseRecording +from spikeinterface.core 
import load, BaseRecordingSnippets, BaseRecording from spikeinterface.core.core_tools import check_json from spikeinterface.core.globals import get_global_job_kwargs from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs @@ -145,9 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - # TODO: deprecate and finally remove this after 0.100 - d = {"warning": "The recording is not serializable to json"} - rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") + raise RuntimeError( + "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file." + ) return output_folder @@ -209,9 +210,9 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False): ) recording = None else: - recording = load_extractor(json_file, base_folder=output_folder) + recording = load(json_file, base_folder=output_folder) elif pickle_file.exists(): - recording = load_extractor(pickle_file, base_folder=output_folder) + recording = load(pickle_file, base_folder=output_folder) return recording diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 2a9fb34267..ec15506006 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -66,7 +66,7 @@ class Kilosort4Sorter(BaseSorter): "do_correction": True, "keep_good_only": False, "skip_kilosort_preprocessing": False, - "use_binary_file": None, + "use_binary_file": True, "delete_recording_dat": True, } @@ -116,7 +116,7 @@ class Kilosort4Sorter(BaseSorter): "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " - "Default is None. (spikeinterface parameter)", + "Default is True. (spikeinterface parameter)", "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. 
(spikeinterface parameter)", } diff --git a/src/spikeinterface/sorters/external/tests/test_kilosort4.py b/src/spikeinterface/sorters/external/tests/test_kilosort4.py index dbaf3ffc5e..5e1e908411 100644 --- a/src/spikeinterface/sorters/external/tests/test_kilosort4.py +++ b/src/spikeinterface/sorters/external/tests/test_kilosort4.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path -from spikeinterface import load_extractor, generate_ground_truth_recording +from spikeinterface import load, generate_ground_truth_recording from spikeinterface.sorters import Kilosort4Sorter, run_sorter from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite @@ -15,7 +15,7 @@ class Kilosort4SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): # 4 channels is to few for KS4 def setUp(self): if (self.cache_folder / "rec").is_dir(): - recording = load_extractor(self.cache_folder / "rec") + recording = load(self.cache_folder / "rec") else: recording, _ = generate_ground_truth_recording(num_channels=32, durations=[60], seed=0) recording = recording.save(folder=self.cache_folder / "rec", verbose=False, format="binary") diff --git a/src/spikeinterface/sorters/internal/si_based.py b/src/spikeinterface/sorters/internal/si_based.py index 68aeead8e9..bd4324bb87 100644 --- a/src/spikeinterface/sorters/internal/si_based.py +++ b/src/spikeinterface/sorters/internal/si_based.py @@ -1,6 +1,6 @@ from __future__ import annotations -from spikeinterface.core import load_extractor, NumpyRecording +from spikeinterface.core import load, NumpyRecording from spikeinterface.sorters import BaseSorter @@ -20,7 +20,7 @@ def _setup_recording(cls, recording, output_folder, params, verbose): @classmethod def _get_result_from_folder(cls, output_folder): - sorting = load_extractor(output_folder / "sorting") + sorting = load(output_folder / "sorting") return sorting @classmethod diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c3b3099535..a3a3523591 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,12 +6,16 @@ import numpy as np from spikeinterface.core import NumpySorting -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import ( + cache_preprocessing, + get_prototype_and_waveforms_from_recording, + get_shuffled_recording_slices, +) from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.sortinganalyzer import create_sorting_analyzer @@ -24,9 +28,10 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" _default_params = { - "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, + "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, + "filtering": {"freq_min": 150, 
"freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, + "whitening": {"mode": "local", "regularize": False}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { "method": "uniform", @@ -36,7 +41,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "seed": 42, }, "apply_motion_correction": True, - "motion_correction": {"preset": "nonrigid_fast_and_accurate"}, + "motion_correction": {"preset": "dredge_fast"}, "merging": { "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2}, "correlograms_kwargs": {}, @@ -46,12 +51,13 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): }, }, "clustering": {"legacy": True}, - "matching": {"method": "wobble"}, + "matching": {"method": "circus-omp-svd"}, "apply_preprocessing": True, "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, - "job_kwargs": {"n_jobs": 0.8}, + "job_kwargs": {"n_jobs": 0.5}, + "seed": 42, "debug": False, } @@ -62,6 +68,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): and also the radius_um used to be considered during clustering", "sparsity": "A dictionary to be passed to all the calls to sparsify the templates", "filtering": "A dictionary for the high_pass filter to be used during preprocessing", + "whitening": "A dictionary for the whitening option to be used during preprocessing", "detection": "A dictionary for the peak detection node (locally_exclusive)", "selection": "A dictionary for the peak selection node. Default is to use smart_sampling_amplitudes, with a minimum of 20000 peaks\ and 5000 peaks per electrode on average.", @@ -72,18 +79,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ - median reference + zscore", + median reference + whitening", + "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not", + "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)", "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \ memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting", "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", + "seed": "An int to control how chunks are shuffled while detecting peaks", "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise. In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes - being discovered.""" + being discovered. 
The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface""" @classmethod def get_sorter_version(cls): @@ -98,6 +108,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): except: HAVE_HDBSCAN = False + try: + import torch + except ImportError: + HAVE_TORCH = False + print("spykingcircus2 could benefit from using torch. Consider installing it") + assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" # this is importanted only on demand because numba import are too heavy @@ -106,11 +122,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.tools import remove_empty_templates - from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction - from spikeinterface.sortingcomponents.tools import get_prototype_spike + from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction - job_kwargs = params["job_kwargs"] - job_kwargs = fix_job_kwargs(job_kwargs) + job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -119,16 +133,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): num_channels = recording.get_num_channels() ms_before = params["general"].get("ms_before", 2) ms_after = params["general"].get("ms_after", 2) - radius_um = params["general"].get("radius_um", 100) + radius_um = params["general"].get("radius_um", 75) exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2) ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: + if verbose: + print("Preprocessing the recording (bandpass filtering + CMR + whitening)") recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") if num_channels > 1: recording_f = common_reference(recording_f) else: + if verbose: + print("Skipping preprocessing (whitening only)") recording_f = recording recording_f.annotate(is_filtered=True) @@ -143,15 +161,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" params["motion_correction"].update({"folder": motion_folder}) - recording_f = correct_motion(recording_f, **params["motion_correction"]) + recording_f = correct_motion(recording_f, **params["motion_correction"], **job_kwargs) else: motion_folder = None ## We need to whiten before the template matching step, to boost the results # TODO add , regularize=True chen ready - recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True) + whitening_kwargs = params["whitening"].copy() + whitening_kwargs["dtype"] = "float32" + whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False) + if num_channels == 1: + whitening_kwargs["regularize"] = False + if whitening_kwargs["regularize"]: + whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} - noise_levels = get_noise_levels(recording_w, return_scaled=False) + recording_w = whiten(recording_f, **whitening_kwargs) + noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) if 
recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -162,9 +187,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() - detection_params.update(job_kwargs) - - detection_params["radius_um"] = detection_params.get("radius_um", 50) + selection_params = params["selection"].copy() + detection_params["radius_um"] = radius_um detection_params["exclude_sweep_ms"] = exclude_sweep_ms detection_params["noise_levels"] = noise_levels @@ -172,23 +196,47 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) + skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform" + max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels + n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) + + if params["debug"]: + clustering_folder = sorter_output_folder / "clustering" + clustering_folder.mkdir(parents=True, exist_ok=True) + np.save(clustering_folder / "noise_levels.npy", noise_levels) if params["matched_filtering"]: - prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( + recording_w, + n_peaks=10000, + ms_before=ms_before, + ms_after=ms_after, + seed=params["seed"], + **detection_params, + **job_kwargs, + ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before + if params["debug"]: + np.save(clustering_folder / "waveforms.npy", waveforms) + np.save(clustering_folder / "prototype.npy", prototype) + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) + else: + waveforms = None + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in detection_params: - detection_params.pop(value) - - detection_params["chunk_duration"] = "100ms" - - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params) - - if verbose: - print("We found %d peaks in total" % len(peaks)) + if not skip_peaks and verbose: + print("Found %d peaks in total" % len(peaks)) if params["multi_units_only"]: sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) @@ -196,14 +244,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible selection_params = params["selection"] - selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels - selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - + selection_params["n_peaks"] = n_peaks selection_params.update({"noise_levels": noise_levels}) 
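# n_peaks is the same budget computed above: max(min_n_peaks, n_peaks_per_channel * num_channels).
# When skip_after_n_peaks was set, detection already stopped once that many peaks were collected,
# scanning chunks in shuffled order (get_shuffled_recording_slices) so the subset is drawn from
# across the recording rather than only its start.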
selected_peaks = select_peaks(peaks, **selection_params) if verbose: - print("We kept %d peaks for clustering" % len(selected_peaks)) + print("Kept %d peaks for clustering" % len(selected_peaks)) ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets @@ -213,11 +259,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["radius_um"] = radius_um clustering_params["waveforms"]["ms_before"] = ms_before clustering_params["waveforms"]["ms_after"] = ms_after - clustering_params["job_kwargs"] = job_kwargs + clustering_params["few_waveforms"] = waveforms clustering_params["noise_levels"] = noise_levels - clustering_params["ms_before"] = exclude_sweep_ms - clustering_params["ms_after"] = exclude_sweep_ms + clustering_params["ms_before"] = ms_before + clustering_params["ms_after"] = ms_after + clustering_params["verbose"] = verbose clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) legacy = clustering_params.get("legacy", True) @@ -227,7 +275,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_method = "random_projections" labels, peak_labels = find_cluster_from_peaks( - recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params + recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) ## We get the labels for our peaks @@ -242,12 +290,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"]))) sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids) - clustering_folder = sorter_output_folder / "clustering" - clustering_folder.mkdir(parents=True, exist_ok=True) - - if not params["debug"]: - shutil.rmtree(clustering_folder) - else: + if params["debug"]: + np.save(clustering_folder / "peak_labels", peak_labels) np.save(clustering_folder / "labels", labels) np.save(clustering_folder / "peaks", selected_peaks) @@ -278,16 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"].pop("method") matching_params = params["matching"].copy() matching_params["templates"] = templates - matching_job_params = job_kwargs.copy() if matching_method is not None: - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in matching_job_params: - matching_job_params[value] = None - matching_job_params["chunk_duration"] = "100ms" - spikes = find_spikes_from_templates( - recording_w, matching_method, method_kwargs=matching_params, **matching_job_params + recording_w, matching_method, method_kwargs=matching_params, **job_kwargs ) if params["debug"]: @@ -296,7 +334,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(fitting_folder / "spikes", spikes) if verbose: - print("We found %d spikes" % len(spikes)) + print("Found %d spikes" % len(spikes)) ## And this is it! 
We have a spyking circus sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) @@ -336,10 +374,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting.save(folder=curation_folder) # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + sorting = final_cleaning_circus(recording_w, sorting, templates, merging_params, **job_kwargs) if verbose: - print(f"Final merging, keeping {len(sorting.unit_ids)} units") + print(f"Kept {len(sorting.unit_ids)} units after final merging") folder_to_delete = None cache_mode = params["cache_preprocessing"].get("mode", "memory") @@ -378,17 +416,18 @@ def create_sorting_analyzer_with_templates(sorting, recording, templates, remove return sa -def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): +def final_cleaning_circus(recording, sorting, templates, merging_kwargs, **job_kwargs): from spikeinterface.core.sorting_tools import apply_merges_to_sorting sa = create_sorting_analyzer_with_templates(sorting, recording, templates) - sa.compute("unit_locations", method="monopolar_triangulation") + sa.compute("unit_locations", method="monopolar_triangulation", **job_kwargs) similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {}) - sa.compute("template_similarity", **similarity_kwargs) + sa.compute("template_similarity", **similarity_kwargs, **job_kwargs) correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {}) - sa.compute("correlograms", **correlograms_kwargs) + sa.compute("correlograms", **correlograms_kwargs, **job_kwargs) + auto_merge_kwargs = merging_kwargs.pop("auto_merge", {}) merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs) sorting = apply_merges_to_sorting(sa.sorting, merges) diff --git a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py index 333bcdbc32..df6e3821bb 100644 --- a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py +++ b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py @@ -4,12 +4,18 @@ from spikeinterface.sorters import Spykingcircus2Sorter +from pathlib import Path + class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Spykingcircus2Sorter if __name__ == "__main__": + from spikeinterface import set_global_job_kwargs + + set_global_job_kwargs(n_jobs=1, progress_bar=False) test = SpykingCircus2SorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" test.setUp() test.test_with_run() diff --git a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py index 58d6c15c8d..b256dd1328 100644 --- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py @@ -4,6 +4,8 @@ from spikeinterface.sorters import Tridesclous2Sorter +from pathlib import Path + class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Tridesclous2Sorter @@ -11,5 +13,6 @@ class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase if __name__ == "__main__": test = Tridesclous2SorterCommonTestSuite() + test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters" test.setUp() test.test_with_run() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py 
b/src/spikeinterface/sorters/internal/tridesclous2.py index 57755cd759..a180fb4e02 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -226,7 +226,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"]["method"] matching_params = params["matching"]["method_kwargs"].copy() matching_params["templates"] = templates - matching_params["noise_levels"] = noise_levels + if params["matching"]["method"] in ("tdc-peeler",): + matching_params["noise_levels"] = noise_levels spikes = find_spikes_from_templates( recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs ) diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 7ed5b29556..db660804aa 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -188,7 +188,7 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal _slurm_script = """#! {python} from numpy import array -from spikeinterface import load_extractor +from spikeinterface import load from spikeinterface.sorters import run_sorter rec_dict = {recording_dict} @@ -196,7 +196,7 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal kwargs = dict( {kwargs_txt} ) -kwargs['recording'] = load_extractor(rec_dict) +kwargs['recording'] = load(rec_dict) run_sorter(**kwargs) """ diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index d28af7b99c..d536d2480a 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -16,7 +16,7 @@ from .. import __version__ as si_version -from ..core import BaseRecording, NumpySorting, load_extractor +from ..core import BaseRecording, NumpySorting, load from ..core.core_tools import check_json, is_editable_mode from .sorterlist import sorter_dict from .utils import ( @@ -408,7 +408,7 @@ def run_sorter_container( py_script = f""" import json from pathlib import Path -from spikeinterface import load_extractor +from spikeinterface import load from spikeinterface.sorters import run_sorter_local if __name__ == '__main__': @@ -417,9 +417,9 @@ def run_sorter_container( json_rec = Path('{parent_folder_unix}/in_container_recording.json') pickle_rec = Path('{parent_folder_unix}/in_container_recording.pickle') if json_rec.exists(): - recording = load_extractor(json_rec) + recording = load(json_rec) else: - recording = load_extractor(pickle_rec) + recording = load(pickle_rec) # load params in container with open('{parent_folder_unix}/in_container_params.json', encoding='utf8', mode='r') as f: @@ -652,7 +652,7 @@ def run_sorter_container( sorting = SorterClass.get_result_from_folder(folder) except Exception as e: try: - sorting = load_extractor(in_container_sorting_folder) + sorting = load(in_container_sorting_folder) except FileNotFoundError: SpikeSortingError(f"Spike sorting in {mode} failed with the following error:\n{run_sorter_output}") diff --git a/src/spikeinterface/sorters/tests/test_container_tools.py b/src/spikeinterface/sorters/tests/test_container_tools.py index 0369bca860..606fe9940e 100644 --- a/src/spikeinterface/sorters/tests/test_container_tools.py +++ b/src/spikeinterface/sorters/tests/test_container_tools.py @@ -30,8 +30,8 @@ def setup_module(tmp_path_factory): def test_find_recording_folders(setup_module): cache_folder = setup_module - rec1 = si.load_extractor(cache_folder / 
"mono") - rec2 = si.load_extractor(cache_folder / "multi" / "binary.json", base_folder=cache_folder / "multi") + rec1 = si.load(cache_folder / "mono") + rec2 = si.load(cache_folder / "multi" / "binary.json", base_folder=cache_folder / "multi") d1 = rec1.to_dict() d2 = rec2.to_dict() diff --git a/src/spikeinterface/sorters/utils/shellscript.py b/src/spikeinterface/sorters/utils/shellscript.py index 286445dd2d..24f353bf00 100644 --- a/src/spikeinterface/sorters/utils/shellscript.py +++ b/src/spikeinterface/sorters/utils/shellscript.py @@ -86,15 +86,15 @@ def start(self) -> None: if self._verbose: print("RUNNING SHELL SCRIPT: " + cmd) self._start_time = time.time() + encoding = sys.getdefaultencoding() self._process = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True, encoding=encoding ) with open(script_log_path, "w+") as script_log_file: for line in self._process.stdout: script_log_file.write(line) - if ( - self._verbose - ): # Print onto console depending on the verbose property passed on from the sorter class + if self._verbose: + # Print onto console depending on the verbose property passed on from the sorter class print(line) def wait(self, timeout=None) -> Optional[int]: diff --git a/src/spikeinterface/sortingcomponents/benchmark/__init__.py b/src/spikeinterface/sortingcomponents/benchmark/__init__.py deleted file mode 100644 index ad6d444bdb..0000000000 --- a/src/spikeinterface/sortingcomponents/benchmark/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Module to benchmark some sorting components: - * clustering - * motion - * template matching -""" diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py deleted file mode 100644 index a9e404292d..0000000000 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_selection.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - - -@pytest.mark.skip() -def test_benchmark_peak_selection(create_cache_folder): - cache_folder = create_cache_folder - pass - - -if __name__ == "__main__": - test_benchmark_peak_selection() diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index b08ee4d9cb..884e4cace8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,9 +18,9 @@ from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection +from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter from spikeinterface.core.template import Templates from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.sortingcomponents.tools import remove_empty_templates @@ -41,13 +41,7 @@ class CircusClustering: """ _default_params = { - "hdbscan_kwargs": { - "min_cluster_size": 25, - "allow_single_cluster": True, - "core_dist_n_jobs": -1, - "cluster_selection_method": "eom", - # 
"cluster_selection_epsilon" : 5 ## To be optimized - }, + "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5}, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, @@ -58,21 +52,20 @@ class CircusClustering: }, "radius_um": 100, "n_svd": [5, 2], + "few_waveforms": None, "ms_before": 0.5, "ms_after": 0.5, + "noise_threshold": 4, "rank": 5, "noise_levels": None, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -90,12 +83,31 @@ def main_function(cls, recording, peaks, params): tmp_folder.mkdir(parents=True, exist_ok=True) # SVD for time compression - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)) - few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] - ) + if params["few_waveforms"] is None: + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter) + ) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + wfs = few_wfs[:, :, 0] + else: + offset = int(params["waveforms"]["ms_before"] * fs / 1000) + wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter] + + # Ensure all waveforms have a positive max + wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis] + + # Remove outliers + valid = np.argmax(np.abs(wfs), axis=1) == nbefore + wfs = wfs[valid] + + # Perform Hanning filtering + hanning_before = np.hanning(2 * nbefore) + hanning_after = np.hanning(2 * nafter) + hanning = np.concatenate((hanning_before[:nbefore], hanning_after[nafter:])) + wfs *= hanning - wfs = few_wfs[:, :, 0] from sklearn.decomposition import TruncatedSVD tsvd = TruncatedSVD(params["n_svd"][0]) @@ -129,18 +141,21 @@ def main_function(cls, recording, peaks, params): radius_um=radius_um, ) - node2 = TemporalPCAProjection( - recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder + node2 = HanningFilter(recording, parents=[node0, node1], return_output=False) + + node3 = TemporalPCAProjection( + recording, parents=[node0, node2], return_output=True, model_folder_path=model_folder ) - pipeline_nodes = [node0, node1, node2] + pipeline_nodes = [node0, node1, node2, node3] if len(params["recursive_kwargs"]) == 0: + from sklearn.decomposition import PCA all_pc_data = run_node_pipeline( recording, pipeline_nodes, - params["job_kwargs"], + job_kwargs, job_name="extracting features", ) @@ -152,9 +167,9 @@ def main_function(cls, recording, peaks, params): sub_data = sub_data.reshape(len(sub_data), -1) if all_pc_data.shape[1] > params["n_svd"][1]: - tsvd = TruncatedSVD(params["n_svd"][1]) + tsvd = PCA(params["n_svd"][1], whiten=True) else: - tsvd = TruncatedSVD(all_pc_data.shape[1]) + tsvd = PCA(all_pc_data.shape[1], whiten=True) hdbscan_data = tsvd.fit_transform(sub_data) try: @@ -175,7 +190,7 @@ def main_function(cls, recording, peaks, params): _ = run_node_pipeline( recording, pipeline_nodes, - params["job_kwargs"], + job_kwargs, job_name="extracting features", gather_mode="npy", 
gather_kwargs=dict(exist_ok=True), @@ -184,7 +199,7 @@ def main_function(cls, recording, peaks, params): ) sparse_mask = node1.neighbours_mask - neighbours_mask = get_channel_distances(recording) < radius_um + neighbours_mask = get_channel_distances(recording) <= radius_um # np.save(features_folder / "sparse_mask.npy", sparse_mask) np.save(features_folder / "peaks.npy", peaks) @@ -192,6 +207,8 @@ def main_function(cls, recording, peaks, params): original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters + min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10) + peak_labels, _ = split_clusters( original_labels, recording, @@ -202,7 +219,7 @@ def main_function(cls, recording, peaks, params): feature_name="sparse_tsvd", neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - min_size_split=50, + min_size_split=min_size, clusterer_kwargs=d["hdbscan_kwargs"], n_pca_features=params["n_svd"][1], scale_n_pca_by_depth=True, @@ -226,52 +243,64 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + if d["rank"] is not None: from spikeinterface.sortingcomponents.matching.circus import compress_templates - _, _, _, templates_array = compress_templates(templates_array, 5) + _, _, _, templates_array = compress_templates(templates_array, d["rank"]) templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) peak_labels[mask] = -1 - if verbose: - print("We found %d raw clusters, starting to clean with matching..." 
% (len(templates.unit_ids))) + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 - cleaning_matching_params = params["job_kwargs"].copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params.pop(value) - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + if verbose: + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py index c7d57b14e4..e8bc5a1d49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clean.py +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -32,7 +32,6 @@ def clean_clusters( count = np.zeros(n, dtype="int64") for i, label in enumerate(labels_set): count[i] = np.sum(peak_labels == label) - print(count) templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) @@ -42,6 +41,5 @@ def clean_clusters( max_values = -np.min(templates, axis=(1, 2)) elif peak_sign == "pos": max_values = np.max(templates, axis=(1, 2)) - print(max_values) return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 234be686d0..93db9a268f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -570,7 +570,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, ) else: recording = NumpyRecording(zdata, sampling_frequency=fs) - recording = SharedMemoryRecording.from_recording(recording) + recording = SharedMemoryRecording.from_recording(recording, **job_kwargs) recording = recording.set_probe(templates.probe) recording.annotate(is_filtered=True) @@ -587,6 +587,8 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, keep_searching = True + local_job_kargs = {"n_jobs": 1, "progress_bar": False} + DEBUG = False while keep_searching: @@ -602,21 +604,15 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, sub_recording = recording.frame_slice(t_start, t_stop) local_params.update({"ignore_inds": ignore_inds + [i]}) - spikes, computed = find_spikes_from_templates( - sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs - ) - local_params.update( - { - "overlaps": computed["overlaps"], - "normed_templates": computed["normed_templates"], - "norms": computed["norms"], - "temporal": computed["temporal"], - "spatial": computed["spatial"], - "singular": computed["singular"], - "units_overlaps": computed["units_overlaps"], - "unit_overlaps_indices": computed["unit_overlaps_indices"], - } + + spikes, 
more_outputs = find_spikes_from_templates( + sub_recording, + method="circus-omp-svd", + method_kwargs=local_params, + extra_outputs=True, + **local_job_kargs, ) + local_params["precomputed"] = more_outputs valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) if np.sum(valid) > 0: diff --git a/src/spikeinterface/sortingcomponents/clustering/dummy.py b/src/spikeinterface/sortingcomponents/clustering/dummy.py index c1032ee6c6..b5761ad5cf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/dummy.py +++ b/src/spikeinterface/sortingcomponents/clustering/dummy.py @@ -13,7 +13,7 @@ class DummyClustering: _default_params = {} @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.arange(recording.get_num_channels(), dtype="int64") peak_labels = peaks["channel_index"] return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 99881f2f34..ba0fe6f9ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -41,7 +41,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, params = method_class._default_params.copy() params.update(**method_kwargs) - outputs = method_class.main_function(recording, peaks, params) + outputs = method_class.main_function(recording, peaks, params, job_kwargs=job_kwargs) if extra_outputs: return outputs diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4a7b722aea..e618cfbfb6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -261,7 +261,7 @@ def find_merge_pairs( **job_kwargs, # n_jobs=1, # mp_context="fork", - # max_threads_per_process=1, + # max_threads_per_worker=1, # progress_bar=True, ): """ @@ -299,7 +299,7 @@ def find_merge_pairs( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) @@ -316,7 +316,7 @@ def find_merge_pairs( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ), ) as pool: jobs = [] @@ -354,7 +354,7 @@ def find_pair_worker_init( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ): global _ctx _ctx = {} @@ -366,7 +366,7 @@ def find_pair_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker # if isinstance(features_dict_or_folder, dict): # _ctx["features"] = features_dict_or_folder @@ -380,7 +380,7 @@ def find_pair_worker_init( def find_pair_function_wrapper(label0, label1): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( label0, label1, diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index 
ae772206bb..dc76d787f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -25,18 +25,17 @@ class PositionClustering: "hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "debug": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) + peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **job_kwargs) else: peak_locations = d["peak_locations"] diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index d23eb26239..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -42,23 +42,14 @@ class PositionAndFeaturesClustering: "ms_before": 1.5, "ms_after": 1.5, "cleaning_method": "dip", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.preprocessing import QuantileTransformer assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - if "n_jobs" in params["job_kwargs"]: - if params["job_kwargs"]["n_jobs"] == -1: - params["job_kwargs"]["n_jobs"] = os.cpu_count() - - if "core_dist_n_jobs" in params["hdbscan_kwargs"]: - if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: - params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - d = params peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] @@ -80,7 +71,7 @@ def main_function(cls, recording, peaks, params): } features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **job_kwargs ) hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32) @@ -150,10 +141,10 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=None, copy=True, - **params["job_kwargs"], + **job_kwargs, ) - noise_levels = get_noise_levels(recording, return_scaled=False) + noise_levels = get_noise_levels(recording, return_scaled=False, **job_kwargs) labels, peak_labels = remove_duplicates( wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] ) @@ -181,7 +172,7 @@ def main_function(cls, recording, peaks, params): nbefore, nafter, return_scaled=False, - **params["job_kwargs"], + **job_kwargs, ) templates = Templates( templates_array=templates_array, @@ -193,7 +184,7 @@ def main_function(cls, recording, peaks, params): ) labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] + templates, peak_labels, job_kwargs=job_kwargs, **params["cleaning_kwargs"] ) shutil.rmtree(tmp_folder) 
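Editor's note (not part of the patch): the clustering hunks above all follow the same refactor, where "job_kwargs" is removed from each method's _default_params and passed instead as an explicit argument to main_function (forwarded by find_cluster_from_peaks). The sketch below shows what a clustering method is expected to look like under the new convention; MyClustering and its parameters are hypothetical, the only things taken from the diff are the main_function signature and the absence of params["job_kwargs"].

    # minimal sketch, assuming the new main_function(cls, recording, peaks, params, job_kwargs) convention
    import numpy as np

    class MyClustering:
        _default_params = {
            "ms_before": 1.0,   # hypothetical parameter
            "ms_after": 1.0,    # hypothetical parameter
            # note: no "job_kwargs" entry anymore; parallelism settings arrive separately
        }

        @classmethod
        def main_function(cls, recording, peaks, params, job_kwargs=dict()):
            # job_kwargs (n_jobs, chunk_duration, progress_bar, ...) is now forwarded by
            # find_cluster_from_peaks instead of being read from params["job_kwargs"]
            labels = np.unique(peaks["channel_index"])
            peak_labels = peaks["channel_index"].copy()
            return labels, peak_labels
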
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 4dfe3c960c..c4f372fc21 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -38,7 +38,6 @@ class PositionAndPCAClustering: "ms_after": 2.5, "n_components_by_channel": 3, "n_components": 5, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "waveform_mode": "shared_memory", @@ -73,7 +72,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): # res = PositionClustering(recording, peaks, params) assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed" @@ -85,9 +84,7 @@ def main_function(cls, recording, peaks, params): if params["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks( - recording, peaks, **params["peak_localization_kwargs"], **params["job_kwargs"] - ) + peak_locations = localize_peaks(recording, peaks, **params["peak_localization_kwargs"], **job_kwargs) else: peak_locations = params["peak_locations"] @@ -155,7 +152,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask, copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) noise = get_random_data_chunks( @@ -222,7 +219,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask3, copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) clean_peak_labels, peak_sample_shifts = auto_clean_clustering( diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 788addf1e6..0f7390d7ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -26,7 +26,6 @@ class PositionPTPScaledClustering: "ptps": None, "scales": (1, 1, 10), "peak_localization_kwargs": {"method": "center_of_mass"}, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_kwargs": { "min_cluster_size": 20, "min_samples": 20, @@ -38,7 +37,7 @@ class PositionPTPScaledClustering: } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params @@ -60,7 +59,7 @@ def main_function(cls, recording, peaks, params): if d["ptps"] is None: (ptps,) = compute_features_from_peaks( - recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **d["job_kwargs"] + recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **job_kwargs ) else: ptps = d["ptps"] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f7ca999d53..1d4d8881ad 
100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -16,8 +16,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates @@ -54,17 +53,15 @@ class RandomProjectionClustering: "random_seed": 42, "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, + "noise_threshold": 4, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -133,44 +130,59 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) - sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) + + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) + empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) - if verbose: - print("We found %d raw clusters, starting to clean with matching..." 
% (len(templates.unit_ids))) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) + peak_labels[mask] = -1 - cleaning_matching_params = job_kwargs.copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params[value] = None - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + + if verbose: + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 8b9acbc92d..5f8ac99848 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -23,7 +23,7 @@ get_random_data_chunks, extract_waveforms_to_buffers, ) -from .clustering_tools import auto_clean_clustering, auto_split_clustering +from .clustering_tools import auto_clean_clustering class SlidingHdbscanClustering: @@ -55,18 +55,17 @@ class SlidingHdbscanClustering: "auto_merge_quantile_limit": 0.8, "ratio_num_channel_intersect": 0.5, # ~ 'auto_trash_misalignment_shift' : 4, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed" params = cls._check_params(recording, peaks, params) wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params) wfs_arrays2, sparsity_mask2 = cls._prepare_clean( - recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params + recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params, job_kwargs ) clean_peak_labels, peak_sample_shifts = cls._clean_cluster( @@ -100,7 +99,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): d = params tmp_folder = params["tmp_folder"] @@ -145,7 +144,7 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) # noise @@ -401,7 +400,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): return peak_labels @classmethod - def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d): + def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d, job_kwargs=dict()): tmp_folder = d["tmp_folder"] if tmp_folder is None: 
wf_folder = None @@ -465,7 +464,7 @@ def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels dtype=recording.get_dtype(), sparsity_mask=sparsity_mask2, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays2, sparsity_mask2 diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index a6ffa5fdc2..40cedacdc5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -71,11 +71,10 @@ class SlidingNNClustering: "tmp_folder": None, "verbose": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1}, } @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_NUMBA, "SlidingNN needs numba to work" assert HAVE_TORCH, "SlidingNN needs torch to work" assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work" @@ -126,16 +125,16 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays, sparsity_mask @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): d = params - # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) + # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params, job_kwargs) # prepare neighborhood parameters fs = recording.get_sampling_frequency() @@ -228,7 +227,7 @@ def main_function(cls, recording, peaks, params): n_channel_neighbors=d["n_channel_neighbors"], low_memory=d["low_memory"], knn_verbose=d["verbose"], - n_jobs=d["job_kwargs"]["n_jobs"], + n_jobs=job_kwargs["n_jobs"], ) # remove the first nearest neighbor (which should be self) knn_distances = knn_distances[:, 1:] @@ -297,7 +296,7 @@ def main_function(cls, recording, peaks, params): # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed clusterer = hdbscan.HDBSCAN( prediction_data=True, - core_dist_n_jobs=d["job_kwargs"]["n_jobs"], + core_dist_n_jobs=job_kwargs["n_jobs"], **d["hdbscan_kwargs"], ).fit(embeddings_chunk) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 5934bdfbbb..3c2e878c39 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -24,7 +24,7 @@ def split_clusters( peak_labels, recording, features_dict_or_folder, - method="hdbscan_on_local_pca", + method="local_feature_clustering", method_kwargs={}, recursive=False, recursive_depth=None, @@ -65,7 +65,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) original_labels = peak_labels peak_labels = peak_labels.copy() @@ -77,11 +77,10 @@ def split_clusters( max_workers=n_jobs, initializer=split_worker_init, mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, original_labels, method, 
method_kwargs, max_threads_per_worker), ) as pool: labels_set = np.setdiff1d(peak_labels, [-1]) current_max_label = np.max(labels_set) + 1 - jobs = [] for label in labels_set: peak_indices = np.flatnonzero(peak_labels == label) @@ -95,15 +94,14 @@ def split_clusters( for res in iterator: is_split, local_labels, peak_indices = res.result() + # print(is_split, local_labels, peak_indices) if not is_split: continue mask = local_labels >= 0 peak_labels[peak_indices[mask]] = local_labels[mask] + current_max_label peak_labels[peak_indices[~mask]] = local_labels[~mask] - split_count[peak_indices] += 1 - current_max_label += np.max(local_labels[mask]) + 1 if recursive: @@ -120,6 +118,7 @@ def split_clusters( for label in new_labels_set: peak_indices = np.flatnonzero(peak_labels == label) if peak_indices.size > 0: + # print('Relaunched', label, len(peak_indices), recursion_level) jobs.append(pool.submit(split_function_wrapper, peak_indices, recursion_level)) if progress_bar: iterator.total += 1 @@ -134,7 +133,7 @@ def split_clusters( def split_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker ): global _ctx _ctx = {} @@ -145,14 +144,14 @@ def split_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = split_methods_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) _ctx["peaks"] = _ctx["features"]["peaks"] def split_function_wrapper(peak_indices, recursion_level): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_split, local_labels = _ctx["method_class"].split( peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, **_ctx["method_kwargs"] ) @@ -187,7 +186,7 @@ def split( min_size_split=25, n_pca_features=2, scale_n_pca_by_depth=False, - minimum_common_channels=2, + minimum_overlap_ratio=0.25, ): local_labels = np.zeros(peak_indices.size, dtype=np.int64) @@ -199,19 +198,22 @@ def split( # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) - target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + target_intersection_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) + target_union_channels = np.flatnonzero(np.any(neighbours_mask[local_chans, :], axis=0)) + num_intersection = len(target_intersection_channels) + num_union = len(target_union_channels) # TODO fix this a better way, this when cluster have too few overlapping channels - if target_channels.size < minimum_common_channels: + if (num_intersection / num_union) < minimum_overlap_ratio: return False, None aligned_wfs, dont_have_channels = aggregate_sparse_features( - peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_channels + peaks, peak_indices, sparse_features, waveforms_sparse_mask, target_intersection_channels ) local_labels[dont_have_channels] = -2 kept = np.flatnonzero(~dont_have_channels) - + # print(recursion_level, kept.size, min_size_split) if kept.size < min_size_split: return False, None @@ -222,6 +224,8 @@ def split( if flatten_features.shape[1] > n_pca_features: from sklearn.decomposition import PCA + # from 
sklearn.decomposition import TruncatedSVD + if scale_n_pca_by_depth: # tsvd = TruncatedSVD(n_pca_features * recursion_level) tsvd = PCA(n_pca_features * recursion_level, whiten=True) diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py index 13af5b0fab..59472d1374 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -9,27 +9,21 @@ from spikeinterface.core import ( get_channel_distances, - Templates, - compute_sparsity, get_global_tmp_folder, ) from spikeinterface.core.node_pipeline import ( run_node_pipeline, - ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever, ) -from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing -from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel from spikeinterface.sortingcomponents.peak_selection import select_peaks -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.sortingcomponents.clustering.split import split_clusters from spikeinterface.sortingcomponents.clustering.merge import merge_clusters -from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse class TdcClustering: @@ -50,15 +44,12 @@ class TdcClustering: "merge_radius_um": 40.0, "threshold_diff": 1.5, }, - "job_kwargs": {}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): import hdbscan - job_kwargs = params["job_kwargs"] - if params["folder"] is None: randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) clustering_folder = get_global_tmp_folder() / f"tdcclustering_{randname}" diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index e2a0d273d6..64cc0f39c4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -172,8 +172,6 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): """ - print("apply_waveforms_shift") - if inplace: aligned_waveforms = waveforms else: @@ -193,6 +191,4 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): else: aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] - print("apply_waveforms_shift DONE") - return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py new file mode 100644 index 0000000000..0e60a9e864 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -0,0 +1,48 @@ +import numpy as np +from spikeinterface.core import Templates +from spikeinterface.core.node_pipeline import PeakDetector + +_base_matching_dtype = [ + ("sample_index", "int64"), + ("channel_index", "int64"), + ("cluster_index", "int64"), + ("amplitude", "float64"), + ("segment_index", "int64"), +] + + +class BaseTemplateMatching(PeakDetector): + def __init__(self, recording, templates, return_output=True, parents=None): + # TODO make a sharedmem of template here + # TODO maybe check that channel_id are the same with recording + + assert isinstance( + 
templates, Templates + ), f"The templates supplied is of type {type(templates)} and must be a Templates" + self.templates = templates + PeakDetector.__init__(self, recording, return_output=return_output, parents=parents) + + def get_dtype(self): + return np.dtype(_base_matching_dtype) + + def get_trace_margin(self): + raise NotImplementedError + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + spikes = self.compute_matching(traces, start_frame, end_frame, segment_index) + spikes["segment_index"] = segment_index + + margin = self.get_trace_margin() + if margin > 0 and spikes.size > 0: + keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) + spikes = spikes[keep] + + # node pipeline need to return a tuple + return (spikes,) + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + raise NotImplementedError + + def get_extra_outputs(self): + # can be overwritten if need to ouput some variables with a dict + return None diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ad7391a297..3b97f2dc6a 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -9,6 +9,7 @@ from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel from spikeinterface.core.template import Templates + spike_dtype = [ ("sample_index", "int64"), ("channel_index", "int64"), @@ -17,7 +18,16 @@ ("segment_index", "int64"), ] -from .main import BaseTemplateMatchingEngine +try: + import torch + import torch.nn.functional as F + + HAVE_TORCH = True + from torch.nn.functional import conv1d +except ImportError: + HAVE_TORCH = False + +from .base import BaseTemplateMatching def compress_templates( @@ -42,9 +52,9 @@ def compress_templates( temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) # Keep only the strongest components - temporal = temporal[:, :, :approx_rank] - singular = singular[:, :approx_rank] - spatial = spatial[:, :approx_rank, :] + temporal = temporal[:, :, :approx_rank].astype(np.float32) + singular = singular[:, :approx_rank].astype(np.float32) + spatial = spatial[:, :approx_rank, :].astype(np.float32) if return_new_templates: templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial) @@ -89,7 +99,7 @@ def compute_overlaps(templates, num_samples, num_channels, sparsities): return new_overlaps -class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): +class CircusOMPSVDPeeler(BaseTemplateMatching): """ Orthogonal Matching Pursuit inspired from Spyking Circus sorter @@ -106,213 +116,247 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): Parameters ---------- - amplitude: tuple + amplitude : tuple (Minimal, Maximal) amplitudes allowed for every template - max_failures: int + max_failures : int Stopping criteria of the OMP algorithm, as number of retry while updating amplitudes - sparse_kwargs: dict + sparse_kwargs : dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. - rank: int, default: 5 + rank : int, default: 5 Number of components used internally by the SVD - vicinity: int + vicinity : int Size of the area surrounding a spike to perform modification (expressed in terms of template temporal width) + engine : string in ["numpy", "torch", "auto"]. Default "auto" + The engine to use for the convolutions + torch_device : string in ["cpu", "cuda", None]. 
Default "cpu" + Controls torch device if the torch engine is selected ----- """ - _default_params = { - "amplitudes": [0.6, np.inf], - "stop_criteria": "max_failures", - "max_failures": 10, - "omp_min_sps": 0.1, - "relative_error": 5e-5, - "templates": None, - "rank": 5, - "ignore_inds": [], - "vicinity": 3, - } + _more_output_keys = [ + "norms", + "temporal", + "spatial", + "singular", + "units_overlaps", + "unit_overlaps_indices", + "normed_templates", + "overlaps", + ] + + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + amplitudes=[0.6, np.inf], + stop_criteria="max_failures", + max_failures=5, + omp_min_sps=0.1, + relative_error=5e-5, + rank=5, + ignore_inds=[], + vicinity=2, + precomputed=None, + engine="numpy", + torch_device="cpu", + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + self.num_channels = recording.get_num_channels() + self.num_samples = templates.num_samples + self.nbefore = templates.nbefore + self.nafter = templates.nafter + self.sampling_frequency = recording.get_sampling_frequency() + self.vicinity = vicinity * self.num_samples + assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto" + if engine == "auto": + if HAVE_TORCH: + self.engine = "torch" + else: + self.engine = "numpy" + else: + if engine == "torch": + assert HAVE_TORCH, "please install torch to use the torch engine" + self.engine = engine + + assert torch_device in ["cuda", "cpu", None] + self.torch_device = torch_device + + self.amplitudes = amplitudes + self.stop_criteria = stop_criteria + self.max_failures = max_failures + self.omp_min_sps = omp_min_sps + self.relative_error = relative_error + self.rank = rank + + self.num_templates = len(templates.unit_ids) + + if precomputed is None: + self._prepare_templates() + else: + for key in self._more_output_keys: + assert precomputed[key] is not None, "If templates are provided, %d should also be there" % key + setattr(self, key, precomputed[key]) + + self.ignore_inds = np.array(ignore_inds) + + self.unit_overlaps_tables = {} + for i in range(self.num_templates): + self.unit_overlaps_tables[i] = np.zeros(self.num_templates, dtype=int) + self.unit_overlaps_tables[i][self.unit_overlaps_indices[i]] = np.arange(len(self.unit_overlaps_indices[i])) - @classmethod - def _prepare_templates(cls, d): - templates = d["templates"] - num_templates = len(d["templates"].unit_ids) + self.margin = 2 * self.num_samples + self.is_pushed = False - assert d["stop_criteria"] in ["max_failures", "omp_min_sps", "relative_error"] + def _prepare_templates(self): - sparsity = templates.sparsity.mask + assert self.stop_criteria in ["max_failures", "omp_min_sps", "relative_error"] + + if self.templates.sparsity is None: + sparsity = np.ones((self.num_templates, self.num_channels), dtype=bool) + else: + sparsity = self.templates.sparsity.mask units_overlaps = np.sum(np.logical_and(sparsity[:, np.newaxis, :], sparsity[np.newaxis, :, :]), axis=2) - d["units_overlaps"] = units_overlaps > 0 - d["unit_overlaps_indices"] = {} - for i in range(num_templates): - (d["unit_overlaps_indices"][i],) = np.nonzero(d["units_overlaps"][i]) + self.units_overlaps = units_overlaps > 0 + self.unit_overlaps_indices = {} + for i in range(self.num_templates): + self.unit_overlaps_indices[i] = np.flatnonzero(self.units_overlaps[i]) - templates_array = templates.get_dense_templates().copy() + templates_array = self.templates.get_dense_templates().copy() # Then we keep only the strongest 
components - d["temporal"], d["singular"], d["spatial"], templates_array = compress_templates(templates_array, d["rank"]) + self.temporal, self.singular, self.spatial, templates_array = compress_templates(templates_array, self.rank) - d["normed_templates"] = np.zeros(templates_array.shape, dtype=np.float32) - d["norms"] = np.zeros(num_templates, dtype=np.float32) + self.normed_templates = np.zeros(templates_array.shape, dtype=np.float32) + self.norms = np.zeros(self.num_templates, dtype=np.float32) # And get the norms, saving compressed templates for CC matrix - for count in range(num_templates): + for count in range(self.num_templates): template = templates_array[count][:, sparsity[count]] - d["norms"][count] = np.linalg.norm(template) - d["normed_templates"][count][:, sparsity[count]] = template / d["norms"][count] + self.norms[count] = np.linalg.norm(template) + self.normed_templates[count][:, sparsity[count]] = template / self.norms[count] - d["temporal"] /= d["norms"][:, np.newaxis, np.newaxis] - d["temporal"] = np.flip(d["temporal"], axis=1) + self.temporal /= self.norms[:, np.newaxis, np.newaxis] + self.temporal = np.flip(self.temporal, axis=1) - d["overlaps"] = [] - d["max_similarity"] = np.zeros((num_templates, num_templates), dtype=np.float32) - for i in range(num_templates): - num_overlaps = np.sum(d["units_overlaps"][i]) - overlapping_units = np.where(d["units_overlaps"][i])[0] + self.overlaps = [] + self.max_similarity = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) + for i in range(self.num_templates): + num_overlaps = np.sum(self.units_overlaps[i]) + overlapping_units = np.flatnonzero(self.units_overlaps[i]) # Reconstruct unit template from SVD Matrices - data = d["temporal"][i] * d["singular"][i][np.newaxis, :] - template_i = np.matmul(data, d["spatial"][i, :, :]) + data = self.temporal[i] * self.singular[i][np.newaxis, :] + template_i = np.matmul(data, self.spatial[i, :, :]) template_i = np.flipud(template_i) - unit_overlaps = np.zeros([num_overlaps, 2 * d["num_samples"] - 1], dtype=np.float32) + unit_overlaps = np.zeros([num_overlaps, 2 * self.num_samples - 1], dtype=np.float32) for count, j in enumerate(overlapping_units): overlapped_channels = sparsity[j] visible_i = template_i[:, overlapped_channels] - spatial_filters = d["spatial"][j, :, overlapped_channels] + spatial_filters = self.spatial[j, :, overlapped_channels] spatially_filtered_template = np.matmul(visible_i, spatial_filters) - visible_i = spatially_filtered_template * d["singular"][j] + visible_i = spatially_filtered_template * self.singular[j] for rank in range(visible_i.shape[1]): - unit_overlaps[count, :] += np.convolve(visible_i[:, rank], d["temporal"][j][:, rank], mode="full") + unit_overlaps[count, :] += np.convolve(visible_i[:, rank], self.temporal[j][:, rank], mode="full") - d["max_similarity"][i, j] = np.max(unit_overlaps[count]) + self.max_similarity[i, j] = np.max(unit_overlaps[count]) - d["overlaps"].append(unit_overlaps) + self.overlaps.append(unit_overlaps) - if d["amplitudes"] is None: - distances = np.sort(d["max_similarity"], axis=1)[:, ::-1] + if self.amplitudes is None: + distances = np.sort(self.max_similarity, axis=1)[:, ::-1] distances = 1 - distances[:, 1] / 2 - d["amplitudes"] = np.zeros((num_templates, 2)) - d["amplitudes"][:, 0] = distances - d["amplitudes"][:, 1] = np.inf - - d["spatial"] = np.moveaxis(d["spatial"], [0, 1, 2], [1, 0, 2]) - d["temporal"] = np.moveaxis(d["temporal"], [0, 1, 2], [1, 2, 0]) - d["singular"] = d["singular"].T[:, :, np.newaxis] - 
return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls._default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) - - d["num_channels"] = recording.get_num_channels() - d["num_samples"] = d["templates"].num_samples - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["sampling_frequency"] = recording.get_sampling_frequency() - d["vicinity"] *= d["num_samples"] - - if "overlaps" not in d: - d = cls._prepare_templates(d) - else: - for key in [ - "norms", - "temporal", - "spatial", - "singular", - "units_overlaps", - "unit_overlaps_indices", - ]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key - - d["num_templates"] = len(d["templates"].templates_array) - d["ignore_inds"] = np.array(d["ignore_inds"]) - - d["unit_overlaps_tables"] = {} - for i in range(d["num_templates"]): - d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) - d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - if kwargs["vicinity"] > 0: - margin = kwargs["vicinity"] - else: - margin = 2 * kwargs["num_samples"] - return margin - - @classmethod - def main_function(cls, traces, d): + self.amplitudes = np.zeros((self.num_templates, 2)) + self.amplitudes[:, 0] = distances + self.amplitudes[:, 1] = np.inf + + self.spatial = np.moveaxis(self.spatial, [0, 1, 2], [1, 0, 2]) + self.temporal = np.moveaxis(self.temporal, [0, 1, 2], [1, 2, 0]) + self.singular = self.singular.T[:, :, np.newaxis] + + def _push_to_torch(self): + if self.engine == "torch": + self.spatial = torch.as_tensor(self.spatial, device=self.torch_device) + self.singular = torch.as_tensor(self.singular, device=self.torch_device) + self.temporal = torch.as_tensor(self.temporal.copy(), device=self.torch_device).swapaxes(0, 1) + self.temporal = torch.flip(self.temporal, (2,)) + self.is_pushed = True + + def get_extra_outputs(self): + output = {} + for key in self._more_output_keys: + output[key] = getattr(self, key) + return output + + def get_trace_margin(self): + return self.margin + + def compute_matching(self, traces, start_frame, end_frame, segment_index): import scipy.spatial import scipy + from scipy import ndimage - (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) + if not self.is_pushed: + self._push_to_torch() + (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) - num_templates = d["num_templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - overlaps_array = d["overlaps"] - norms = d["norms"] omp_tol = np.finfo(np.float32).eps - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - if isinstance(d["amplitudes"], list): - min_amplitude, max_amplitude = d["amplitudes"] + neighbor_window = self.num_samples - 1 + + if isinstance(self.amplitudes, list): + min_amplitude, max_amplitude = self.amplitudes else: - min_amplitude, max_amplitude = d["amplitudes"][:, 0], d["amplitudes"][:, 1] + min_amplitude, max_amplitude = self.amplitudes[:, 0], self.amplitudes[:, 
1] min_amplitude = min_amplitude[:, np.newaxis] max_amplitude = max_amplitude[:, np.newaxis] - ignore_inds = d["ignore_inds"] - vicinity = d["vicinity"] - num_timesteps = len(traces) + if self.engine == "torch": + blank = np.zeros((neighbor_window, self.num_channels), dtype=np.float32) + traces = np.vstack((blank, traces, blank)) + num_timesteps = traces.shape[0] + torch_traces = torch.as_tensor(traces.T[np.newaxis, :, :], device=self.torch_device) + num_templates, num_channels = self.temporal.shape[0], self.temporal.shape[1] + spatially_filtered_data = torch.matmul(self.spatial, torch_traces) + scaled_filtered_data = (spatially_filtered_data * self.singular).swapaxes(0, 1) + scaled_filtered_data_ = scaled_filtered_data.reshape(1, num_templates * num_channels, num_timesteps) + scalar_products = conv1d(scaled_filtered_data_, self.temporal, groups=num_templates, padding="valid") + scalar_products = scalar_products.cpu().numpy()[0, :, self.num_samples - 1 : -neighbor_window] + else: + num_timesteps = traces.shape[0] + num_peaks = num_timesteps - neighbor_window + conv_shape = (self.num_templates, num_peaks) + scalar_products = np.zeros(conv_shape, dtype=np.float32) + # Filter using overlap-and-add convolution + spatially_filtered_data = np.matmul(self.spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * self.singular + from scipy import signal + + objective_by_rank = signal.oaconvolve(scaled_filtered_data, self.temporal, axes=2, mode="valid") + scalar_products += np.sum(objective_by_rank, axis=0) - num_peaks = num_timesteps - num_samples + 1 - conv_shape = (num_templates, num_peaks) - scalar_products = np.zeros(conv_shape, dtype=np.float32) + num_peaks = scalar_products.shape[1] # Filter using overlap-and-add convolution - if len(ignore_inds) > 0: - not_ignored = ~np.isin(np.arange(num_templates), ignore_inds) - spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] - objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" - ) - scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) - scalar_products[ignore_inds] = -np.inf - else: - spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"] - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, d["temporal"], axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) + if len(self.ignore_inds) > 0: + scalar_products[self.ignore_inds] = -np.inf + not_ignored = ~np.isin(np.arange(self.num_templates), self.ignore_inds) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - M = np.zeros((num_templates, num_templates), dtype=np.float32) + M = np.zeros((self.num_templates, self.num_templates), dtype=np.float32) all_selections = np.empty((2, scalar_products.size), dtype=np.int32) final_amplitudes = np.zeros(scalar_products.shape, dtype=np.float32) @@ -320,18 +364,16 @@ def main_function(cls, traces, d): full_sps = scalar_products.copy() - neighbors = {} - all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - if d["stop_criteria"] == "omp_min_sps": - stop_criteria = d["omp_min_sps"] * np.maximum(d["norms"], np.sqrt(num_channels * num_samples)) - elif d["stop_criteria"] == "max_failures": + if self.stop_criteria == "omp_min_sps": + 
stop_criteria = self.omp_min_sps * np.maximum(self.norms, np.sqrt(self.num_channels * self.num_samples)) + elif self.stop_criteria == "max_failures": num_valids = 0 - nb_failures = d["max_failures"] - elif d["stop_criteria"] == "relative_error": - if len(ignore_inds) > 0: + nb_failures = self.max_failures + elif self.stop_criteria == "relative_error": + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = np.linalg.norm(scalar_products) @@ -340,128 +382,141 @@ def main_function(cls, traces, d): do_loop = True while do_loop: - best_amplitude_ind = scalar_products.argmax() - best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) - - if num_selection > 0: - delta_t = selection[1] - peak_index - idx = np.where((delta_t < num_samples) & (delta_t > -num_samples))[0] - myline = neighbor_window + delta_t[idx] - myindices = selection[0, idx] - - local_overlaps = overlaps_array[best_cluster_ind] - overlapping_templates = d["unit_overlaps_indices"][best_cluster_ind] - table = d["unit_overlaps_tables"][best_cluster_ind] - - if num_selection == M.shape[0]: - Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) - Z[:num_selection, :num_selection] = M - M = Z - - mask = np.isin(myindices, overlapping_templates) - a, b = myindices[mask], myline[mask] - M[num_selection, idx[mask]] = local_overlaps[table[a], b] - - if vicinity == 0: - scipy.linalg.solve_triangular( - M[:num_selection, :num_selection], - M[num_selection, :num_selection], - trans=0, - lower=1, - overwrite_b=True, - check_finite=False, - ) - - v = nrm2(M[num_selection, :num_selection]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - is_in_vicinity = np.where(np.abs(delta_t) < vicinity)[0] - if len(is_in_vicinity) > 0: - L = M[is_in_vicinity, :][:, is_in_vicinity] + best_cluster_inds = np.argmax(scalar_products, axis=0, keepdims=True) + products = np.take_along_axis(scalar_products, best_cluster_inds, axis=0) - M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( - L, M[num_selection, is_in_vicinity], trans=0, lower=1, overwrite_b=True, check_finite=False - ) + result = ndimage.maximum_filter(products[0], size=self.vicinity, mode="constant", cval=0) + cond_1 = products[0] / self.norms[best_cluster_inds[0]] > 0.25 + cond_2 = np.abs(products[0] - result) < 1e-9 + peak_indices = np.flatnonzero(cond_1 * cond_2) - v = nrm2(M[num_selection, is_in_vicinity]) ** 2 - Lkk = 1 - v - if Lkk <= omp_tol: # selected atoms are dependent - break - M[num_selection, num_selection] = np.sqrt(Lkk) - else: - M[num_selection, num_selection] = 1.0 - else: - M[0, 0] = 1 + if len(peak_indices) == 0: + break - all_selections[:, num_selection] = [best_cluster_ind, peak_index] - num_selection += 1 + for peak_index in peak_indices: - selection = all_selections[:, :num_selection] - res_sps = full_sps[selection[0], selection[1]] + best_cluster_ind = best_cluster_inds[0, peak_index] - if vicinity == 0: - all_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) - all_amplitudes /= norms[selection[0]] - else: - is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) - all_amplitudes = np.append(all_amplitudes, np.float32(1)) - L = M[is_in_vicinity, :][:, is_in_vicinity] - all_amplitudes[is_in_vicinity], _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) - all_amplitudes[is_in_vicinity] /= 
norms[selection[0][is_in_vicinity]] + if num_selection > 0: + delta_t = selection[1] - peak_index + idx = np.flatnonzero((delta_t < self.num_samples) & (delta_t > -self.num_samples)) + myline = neighbor_window + delta_t[idx] + myindices = selection[0, idx] + + local_overlaps = self.overlaps[best_cluster_ind] + overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] + table = self.unit_overlaps_tables[best_cluster_ind] - diff_amplitudes = all_amplitudes - final_amplitudes[selection[0], selection[1]] - modified = np.where(np.abs(diff_amplitudes) > omp_tol)[0] - final_amplitudes[selection[0], selection[1]] = all_amplitudes + if num_selection == M.shape[0]: + Z = np.zeros((2 * num_selection, 2 * num_selection), dtype=np.float32) + Z[:num_selection, :num_selection] = M + M = Z - for i in modified: - tmp_best, tmp_peak = selection[:, i] - diff_amp = diff_amplitudes[i] * norms[tmp_best] + mask = np.isin(myindices, overlapping_templates) + a, b = myindices[mask], myline[mask] + M[num_selection, idx[mask]] = local_overlaps[table[a], b] + + if self.vicinity == 0: + scipy.linalg.solve_triangular( + M[:num_selection, :num_selection], + M[num_selection, :num_selection], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) - local_overlaps = overlaps_array[tmp_best] - overlapping_templates = d["units_overlaps"][tmp_best] + v = nrm2(M[num_selection, :num_selection]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + is_in_vicinity = np.flatnonzero(np.abs(delta_t) < self.vicinity) + + if len(is_in_vicinity) > 0: + L = M[is_in_vicinity, :][:, is_in_vicinity] + + M[num_selection, is_in_vicinity] = scipy.linalg.solve_triangular( + L, + M[num_selection, is_in_vicinity], + trans=0, + lower=1, + overwrite_b=True, + check_finite=False, + ) + + v = nrm2(M[num_selection, is_in_vicinity]) ** 2 + Lkk = 1 - v + if Lkk <= omp_tol: # selected atoms are dependent + break + M[num_selection, num_selection] = np.sqrt(Lkk) + else: + M[num_selection, num_selection] = 1.0 + else: + M[0, 0] = 1 - if not tmp_peak in neighbors.keys(): - idx = [max(0, tmp_peak - neighbor_window), min(num_peaks, tmp_peak + num_samples)] - tdx = [neighbor_window + idx[0] - tmp_peak, num_samples + idx[1] - tmp_peak - 1] - neighbors[tmp_peak] = {"idx": idx, "tdx": tdx} + all_selections[:, num_selection] = [best_cluster_ind, peak_index] + num_selection += 1 - idx = neighbors[tmp_peak]["idx"] - tdx = neighbors[tmp_peak]["tdx"] + selection = all_selections[:, :num_selection] + res_sps = full_sps[selection[0], selection[1]] - to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] - scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add + if self.vicinity == 0: + new_amplitudes, _ = potrs(M[:num_selection, :num_selection], res_sps, lower=True, overwrite_b=False) + sub_selection = selection + new_amplitudes /= self.norms[sub_selection[0]] + else: + is_in_vicinity = np.append(is_in_vicinity, num_selection - 1) + all_amplitudes = np.append(all_amplitudes, np.float32(1)) + L = M[is_in_vicinity, :][:, is_in_vicinity] + new_amplitudes, _ = potrs(L, res_sps[is_in_vicinity], lower=True, overwrite_b=False) + sub_selection = selection[:, is_in_vicinity] + new_amplitudes /= self.norms[sub_selection[0]] + + diff_amplitudes = new_amplitudes - final_amplitudes[sub_selection[0], sub_selection[1]] + modified = np.flatnonzero(np.abs(diff_amplitudes) > omp_tol) + final_amplitudes[sub_selection[0], sub_selection[1]] = new_amplitudes + + for i in 
modified: + tmp_best, tmp_peak = sub_selection[:, i] + diff_amp = diff_amplitudes[i] * self.norms[tmp_best] + local_overlaps = self.overlaps[tmp_best] + overlapping_templates = self.units_overlaps[tmp_best] + tmp = tmp_peak - neighbor_window + idx = [max(0, tmp), min(num_peaks, tmp_peak + self.num_samples)] + tdx = [idx[0] - tmp, idx[1] - tmp] + to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] + scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add # We stop when updates do not modify the chosen spikes anymore - if d["stop_criteria"] == "omp_min_sps": + if self.stop_criteria == "omp_min_sps": is_valid = scalar_products > stop_criteria[:, np.newaxis] do_loop = np.any(is_valid) - elif d["stop_criteria"] == "max_failures": + elif self.stop_criteria == "max_failures": is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) new_num_valids = np.sum(is_valid) if (new_num_valids - num_valids) > 0: - nb_failures = d["max_failures"] + nb_failures = self.max_failures else: nb_failures -= 1 num_valids = new_num_valids do_loop = nb_failures > 0 - elif d["stop_criteria"] == "relative_error": + elif self.stop_criteria == "relative_error": previous_error = new_error - if len(ignore_inds) > 0: + if len(self.ignore_inds) > 0: new_error = np.linalg.norm(scalar_products[not_ignored]) else: new_error = np.linalg.norm(scalar_products) delta_error = np.abs(new_error / previous_error - 1) - do_loop = delta_error > d["relative_error"] + do_loop = delta_error > self.relative_error is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) num_spikes = len(valid_indices[0]) - spikes["sample_index"][:num_spikes] = valid_indices[1] + d["nbefore"] + spikes["sample_index"][:num_spikes] = valid_indices[1] + self.nbefore spikes["channel_index"][:num_spikes] = 0 spikes["cluster_index"][:num_spikes] = valid_indices[0] spikes["amplitude"][:num_spikes] = final_amplitudes[valid_indices[0], valid_indices[1]] @@ -473,7 +528,7 @@ def main_function(cls, traces, d): return spikes -class CircusPeeler(BaseTemplateMatchingEngine): +class CircusPeeler(BaseTemplateMatching): """ Greedy Template-matching ported from the Spyking Circus sorter @@ -491,27 +546,27 @@ class CircusPeeler(BaseTemplateMatchingEngine): Parameters ---------- - peak_sign: str + peak_sign : str Sign of the peak (neg, pos, or both) - exclude_sweep_ms: float + exclude_sweep_ms : float The number of samples before/after to classify a peak (should be low) - jitter: int + jitter : int The number of samples considered before/after every peak to search for matches - detect_threshold: int + detect_threshold : int The detection threshold - noise_levels: array + noise_levels : array The noise levels, for every channels - random_chunk_kwargs: dict + random_chunk_kwargs : dict Parameters for computing noise levels, if not provided (sub optimal) - max_amplitude: float + max_amplitude : float Maximal amplitude allowed for every template - min_amplitude: float + min_amplitude : float Minimal amplitude allowed for every template - use_sparse_matrix_threshold: float + use_sparse_matrix_threshold : float If density of the templates is below a given threshold, sparse matrix are used (memory efficient) - sparse_kwargs: dict + sparse_kwargs : dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. 
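As an illustrative aside, not part of the patch itself: with the move from kwargs dictionaries to constructor arguments, a peeler like this one is normally driven through find_spikes_from_templates rather than instantiated by hand. A minimal sketch, assuming a recording and a Templates object already exist and that this class is registered under the method name "circus":

from spikeinterface.core import get_noise_levels
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

# noise levels are expected to be computed outside of the peeler
noise_levels = get_noise_levels(recording, return_scaled=False)

spikes = find_spikes_from_templates(
    recording,
    method="circus",
    method_kwargs=dict(
        templates=templates,
        noise_levels=noise_levels,
        detect_threshold=5,
        min_amplitude=0.5,
        max_amplitude=1.5,
    ),
    n_jobs=1,
    chunk_duration="1s",
)

The result is a structured array with the sample_index, channel_index, cluster_index, amplitude and segment_index fields of the matching spike dtype.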
----- @@ -519,225 +574,120 @@ class CircusPeeler(BaseTemplateMatchingEngine): """ - _default_params = { - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "jitter_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "random_chunk_kwargs": {}, - "max_amplitude": 1.5, - "min_amplitude": 0.5, - "use_sparse_matrix_threshold": 0.25, - "templates": None, - } - - @classmethod - def _prepare_templates(cls, d): + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.1, + jitter_ms=0.1, + detect_threshold=5, + noise_levels=None, + random_chunk_kwargs={}, + max_amplitude=1.5, + min_amplitude=0.5, + use_sparse_matrix_threshold=0.25, + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) + + try: + from sklearn.feature_extraction.image import extract_patches_2d + + HAVE_SKLEARN = True + except ImportError: + HAVE_SKLEARN = False + + assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" + + assert (use_sparse_matrix_threshold >= 0) and ( + use_sparse_matrix_threshold <= 1 + ), f"use_sparse_matrix_threshold should be in [0, 1]" + + self.num_channels = recording.get_num_channels() + self.num_samples = templates.num_samples + self.num_templates = len(templates.unit_ids) + + if noise_levels is None: + print("CircusPeeler : noise should be computed outside") + noise_levels = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) + + self.abs_threholds = noise_levels * detect_threshold + + self.use_sparse_matrix_threshold = use_sparse_matrix_threshold + self._prepare_templates() + self.overlaps = compute_overlaps( + self.normed_templates, + self.num_samples, + self.num_channels, + self.sparsities, + ) + + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + + self.nbefore = templates.nbefore + self.nafter = templates.nafter + self.patch_sizes = (templates.num_samples, self.num_channels) + self.sym_patch = self.nbefore == self.nafter + self.jitter = int(jitter_ms * recording.get_sampling_frequency() / 1000.0) + + self.amplitudes = np.zeros((self.num_templates, 2), dtype=np.float32) + self.amplitudes[:, 0] = min_amplitude + self.amplitudes[:, 1] = max_amplitude + + self.margin = max(self.nbefore, self.nafter) * 2 + self.peak_sign = peak_sign + + def _prepare_templates(self): import scipy.spatial import scipy - templates = d["templates"] - num_samples = d["num_samples"] - num_channels = d["num_channels"] - num_templates = d["num_templates"] - use_sparse_matrix_threshold = d["use_sparse_matrix_threshold"] + self.norms = np.zeros(self.num_templates, dtype=np.float32) - d["norms"] = np.zeros(num_templates, dtype=np.float32) + all_units = self.templates.unit_ids - all_units = d["templates"].unit_ids + sparsity = self.templates.sparsity.mask - sparsity = templates.sparsity.mask - - templates_array = templates.get_dense_templates() - d["sparsities"] = {} - d["normed_templates"] = {} + templates_array = self.templates.get_dense_templates() + self.sparsities = {} + self.normed_templates = {} for count, unit_id in enumerate(all_units): - (d["sparsities"][count],) = np.nonzero(sparsity[count]) - d["norms"][count] = np.linalg.norm(templates_array[count]) - templates_array[count] /= d["norms"][count] - d["normed_templates"][count] = templates_array[count][:, sparsity[count]] + self.sparsities[count] = np.flatnonzero(sparsity[count]) + self.norms[count] = np.linalg.norm(templates_array[count]) + templates_array[count] /= 
self.norms[count] + self.normed_templates[count] = templates_array[count][:, sparsity[count]] - templates_array = templates_array.reshape(num_templates, -1) + templates_array = templates_array.reshape(self.num_templates, -1) - nnz = np.sum(templates_array != 0) / (num_templates * num_samples * num_channels) - if nnz <= use_sparse_matrix_threshold: + nnz = np.sum(templates_array != 0) / (self.num_templates * self.num_samples * self.num_channels) + if nnz <= self.use_sparse_matrix_threshold: templates_array = scipy.sparse.csr_matrix(templates_array) print(f"Templates are automatically sparsified (sparsity level is {nnz})") - d["is_dense"] = False + self.is_dense = False else: - d["is_dense"] = True - - d["circus_templates"] = templates_array - - return d - - # @classmethod - # def _mcc_error(cls, bounds, good, bad): - # fn = np.sum((good < bounds[0]) | (good > bounds[1])) - # fp = np.sum((bounds[0] <= bad) & (bad <= bounds[1])) - # tp = np.sum((bounds[0] <= good) & (good <= bounds[1])) - # tn = np.sum((bad < bounds[0]) | (bad > bounds[1])) - # denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) - # if denom > 0: - # mcc = 1 - (tp * tn - fp * fn) / np.sqrt(denom) - # else: - # mcc = 1 - # return mcc - - # @classmethod - # def _cost_function_mcc(cls, bounds, good, bad, delta_amplitude, alpha): - # # We want a minimal error, with the larger bounds that are possible - # cost = alpha * cls._mcc_error(bounds, good, bad) + (1 - alpha) * np.abs( - # (1 - (bounds[1] - bounds[0]) / delta_amplitude) - # ) - # return cost - - # @classmethod - # def _optimize_amplitudes(cls, noise_snippets, d): - # parameters = d - # waveform_extractor = parameters["waveform_extractor"] - # templates = parameters["templates"] - # num_templates = parameters["num_templates"] - # max_amplitude = parameters["max_amplitude"] - # min_amplitude = parameters["min_amplitude"] - # alpha = 0.5 - # norms = parameters["norms"] - # all_units = list(waveform_extractor.sorting.unit_ids) - - # parameters["amplitudes"] = np.zeros((num_templates, 2), dtype=np.float32) - # noise = templates.dot(noise_snippets) / norms[:, np.newaxis] - - # all_amps = {} - # for count, unit_id in enumerate(all_units): - # waveform = waveform_extractor.get_waveforms(unit_id, force_dense=True) - # snippets = waveform.reshape(waveform.shape[0], -1).T - # amps = templates.dot(snippets) / norms[:, np.newaxis] - # good = amps[count, :].flatten() - - # sub_amps = amps[np.concatenate((np.arange(count), np.arange(count + 1, num_templates))), :] - # bad = sub_amps[sub_amps >= good] - # bad = np.concatenate((bad, noise[count])) - # cost_kwargs = [good, bad, max_amplitude - min_amplitude, alpha] - # cost_bounds = [(min_amplitude, 1), (1, max_amplitude)] - # res = scipy.optimize.differential_evolution(cls._cost_function_mcc, bounds=cost_bounds, args=cost_kwargs) - # parameters["amplitudes"][count] = res.x - - # return d - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - try: - from sklearn.feature_extraction.image import extract_patches_2d - - HAVE_SKLEARN = True - except ImportError: - HAVE_SKLEARN = False + self.is_dense = True - assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" - d = cls._default_params.copy() - d.update(kwargs) + self.circus_templates = templates_array - # assert isinstance(d['waveform_extractor'], WaveformExtractor) - for v in ["use_sparse_matrix_threshold"]: - assert (d[v] >= 0) and (d[v] <= 1), f"{v} should be in [0, 1]" + def get_trace_margin(self): + return self.margin - d["num_channels"] = 
recording.get_num_channels() - d["num_samples"] = d["templates"].num_samples - d["num_templates"] = len(d["templates"].unit_ids) + def compute_matching(self, traces, start_frame, end_frame, segment_index): - if d["noise_levels"] is None: - print("CircusPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] - - if "overlaps" not in d: - d = cls._prepare_templates(d) - d["overlaps"] = compute_overlaps( - d["normed_templates"], - d["num_samples"], - d["num_channels"], - d["sparsities"], - ) - else: - for key in ["circus_templates", "norms"]: - assert d[key] is not None, "If templates are provided, %d should also be there" % key + neighbor_window = self.num_samples - 1 - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) - - d["nbefore"] = d["templates"].nbefore - d["nafter"] = d["templates"].nafter - d["patch_sizes"] = ( - d["templates"].num_samples, - d["num_channels"], - ) - d["sym_patch"] = d["nbefore"] == d["nafter"] - d["jitter"] = int(d["jitter_ms"] * recording.get_sampling_frequency() / 1000.0) - - d["amplitudes"] = np.zeros((d["num_templates"], 2), dtype=np.float32) - d["amplitudes"][:, 0] = d["min_amplitude"] - d["amplitudes"][:, 1] = d["max_amplitude"] - # num_segments = recording.get_num_segments() - # if d["waveform_extractor"]._params["max_spikes_per_unit"] is None: - # num_snippets = 1000 - # else: - # num_snippets = 2 * d["waveform_extractor"]._params["max_spikes_per_unit"] - - # num_chunks = num_snippets // num_segments - # noise_snippets = get_random_data_chunks( - # recording, num_chunks_per_segment=num_chunks, chunk_size=d["num_samples"], seed=42 - # ) - # noise_snippets = ( - # noise_snippets.reshape(num_chunks, d["num_samples"], d["num_channels"]) - # .reshape(num_chunks, -1) - # .T - # ) - # parameters = cls._optimize_amplitudes(noise_snippets, d) - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * max(kwargs["nbefore"], kwargs["nafter"]) - return margin - - @classmethod - def main_function(cls, traces, d): - peak_sign = d["peak_sign"] - abs_threholds = d["abs_threholds"] - exclude_sweep_size = d["exclude_sweep_size"] - templates = d["circus_templates"] - num_templates = d["num_templates"] - overlaps = d["overlaps"] - margin = d["margin"] - norms = d["norms"] - jitter = d["jitter"] - patch_sizes = d["patch_sizes"] - num_samples = d["nafter"] + d["nbefore"] - neighbor_window = num_samples - 1 - amplitudes = d["amplitudes"] - sym_patch = d["sym_patch"] - - peak_traces = traces[margin // 2 : -margin // 2, :] + peak_traces = traces[self.margin // 2 : -self.margin // 2, :] peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size ) from sklearn.feature_extraction.image import extract_patches_2d - if jitter > 0: - jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) - jittered_channels = peak_chan_ind[:, np.newaxis] + np.zeros(2 * jitter) + if self.jitter > 0: + jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-self.jitter, self.jitter) + jittered_channels = peak_chan_ind[:, np.newaxis] + 
np.zeros(2 * self.jitter) mask = (jittered_peaks > 0) & (jittered_peaks < len(peak_traces)) jittered_peaks = jittered_peaks[mask] jittered_channels = jittered_channels[mask] @@ -749,26 +699,26 @@ def main_function(cls, traces, d): num_peaks = len(peak_sample_index) - if sym_patch: - snippets = extract_patches_2d(traces, patch_sizes)[peak_sample_index] - peak_sample_index += margin // 2 + if self.sym_patch: + snippets = extract_patches_2d(traces, self.patch_sizes)[peak_sample_index] + peak_sample_index += self.margin // 2 else: - peak_sample_index += margin // 2 - snippet_window = np.arange(-d["nbefore"], d["nafter"]) + peak_sample_index += self.margin // 2 + snippet_window = np.arange(-self.nbefore, self.nafter) snippets = traces[peak_sample_index[:, np.newaxis] + snippet_window] if num_peaks > 0: snippets = snippets.reshape(num_peaks, -1) - scalar_products = templates.dot(snippets.T) + scalar_products = self.circus_templates.dot(snippets.T) else: - scalar_products = np.zeros((num_templates, 0), dtype=np.float32) + scalar_products = np.zeros((self.num_templates, 0), dtype=np.float32) num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) + idx_lookup = np.arange(scalar_products.size).reshape(self.num_templates, -1) - min_sps = (amplitudes[:, 0] * norms)[:, np.newaxis] - max_sps = (amplitudes[:, 1] * norms)[:, np.newaxis] + min_sps = (self.amplitudes[:, 0] * self.norms)[:, np.newaxis] + max_sps = (self.amplitudes[:, 1] * self.norms)[:, np.newaxis] is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) @@ -787,7 +737,7 @@ def main_function(cls, traces, d): idx_neighbor = peak_data[is_valid_nn[0] : is_valid_nn[1]] + neighbor_window if not best_cluster_ind in cached_overlaps.keys(): - cached_overlaps[best_cluster_ind] = overlaps[best_cluster_ind].toarray() + cached_overlaps[best_cluster_ind] = self.overlaps[best_cluster_ind].toarray() to_add = -best_amplitude * cached_overlaps[best_cluster_ind][:, idx_neighbor] @@ -802,7 +752,7 @@ def main_function(cls, traces, d): is_valid = (scalar_products > min_sps) & (scalar_products < max_sps) - spikes["amplitude"][:num_spikes] /= norms[spikes["cluster_index"][:num_spikes]] + spikes["amplitude"][:num_spikes] /= self.norms[spikes["cluster_index"][:num_spikes]] spikes = spikes[:num_spikes] order = np.argsort(spikes["sample_index"]) diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 6e5267cb70..f423d55e2a 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -3,8 +3,11 @@ from threadpoolctl import threadpool_limits import numpy as np -from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs -from spikeinterface.core import get_chunk_with_margin +# from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs +# from spikeinterface.core import get_chunk_with_margin + +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.node_pipeline import run_node_pipeline def find_spikes_from_templates( @@ -21,7 +24,7 @@ def find_spikes_from_templates( method_kwargs : dict, optional Keyword arguments for the chosen method extra_outputs : bool - If True then method_kwargs is also returned + If True then a dict is also returned is also returned **job_kwargs : dict Parameters for ChunkRecordingExecutor verbose : Bool, default: False @@ -31,9 +34,8 
@@ def find_spikes_from_templates( ------- spikes : ndarray Spikes found from templates. - method_kwargs: + outputs: Optionaly returns for debug purpose. - """ from .method_list import matching_methods @@ -42,117 +44,19 @@ def find_spikes_from_templates( job_kwargs = fix_job_kwargs(job_kwargs) method_class = matching_methods[method] + node0 = method_class(recording, **method_kwargs) + nodes = [node0] - # initialize - method_kwargs = method_class.initialize_and_check_kwargs(recording, method_kwargs) - - # add - method_kwargs["margin"] = method_class.get_margin(recording, method_kwargs) - - # serialiaze for worker - method_kwargs_seralized = method_class.serialize_method_kwargs(method_kwargs) - - # and run - func = _find_spikes_chunk - init_func = _init_worker_find_spikes - init_args = (recording, method, method_kwargs_seralized) - processor = ChunkRecordingExecutor( + spikes = run_node_pipeline( recording, - func, - init_func, - init_args, - handle_returns=True, + nodes, + job_kwargs, job_name=f"find spikes ({method})", - verbose=verbose, - **job_kwargs, + gather_mode="memory", + squeeze_output=True, ) - spikes = processor.run() - - spikes = np.concatenate(spikes) - if extra_outputs: - return spikes, method_kwargs + outputs = node0.get_extra_outputs() + return spikes, outputs else: return spikes - - -def _init_worker_find_spikes(recording, method, method_kwargs): - """Initialize worker for finding spikes.""" - - from .method_list import matching_methods - - method_class = matching_methods[method] - method_kwargs = method_class.unserialize_in_worker(method_kwargs) - - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["method"] = method - worker_ctx["method_kwargs"] = method_kwargs - worker_ctx["function"] = method_class.main_function - - return worker_ctx - - -def _find_spikes_chunk(segment_index, start_frame, end_frame, worker_ctx): - """Find spikes from a chunk of data.""" - - # recover variables of the worker - recording = worker_ctx["recording"] - method = worker_ctx["method"] - method_kwargs = worker_ctx["method_kwargs"] - margin = method_kwargs["margin"] - - # load trace in memory given some margin - recording_segment = recording._recording_segments[segment_index] - traces, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, margin, add_zeros=True - ) - - function = worker_ctx["function"] - - with threadpool_limits(limits=1): - spikes = function(traces, method_kwargs) - - # remove spikes in margin - if margin > 0: - keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) - spikes = spikes[keep] - - spikes["sample_index"] += start_frame - margin - spikes["segment_index"] = segment_index - return spikes - - -# generic class for template engine -class BaseTemplateMatchingEngine: - default_params = {} - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - """This function runs before loops""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def serialize_method_kwargs(cls, kwargs): - """This function serializes kwargs to distribute them to workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def unserialize_in_worker(cls, recording, kwargs): - """This function unserializes kwargs in workers""" - # need to be implemented in subclass - raise NotImplementedError - - @classmethod - def get_margin(cls, recording, kwargs): - # need to be implemented in 
subclass - raise NotImplementedError - - @classmethod - def main_function(cls, traces, method_kwargs): - """This function returns the number of samples for the chunk margins""" - # need to be implemented in subclass - raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py index 0dc71d789b..26f093c187 100644 --- a/src/spikeinterface/sortingcomponents/matching/naive.py +++ b/src/spikeinterface/sortingcomponents/matching/naive.py @@ -4,115 +4,68 @@ import numpy as np -from spikeinterface.core import get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_noise_levels, get_channel_distances from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates - -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] - - -from .main import BaseTemplateMatchingEngine - - -class NaiveMatching(BaseTemplateMatchingEngine): - """ - This is a naive template matching that does not resolve collision - and does not take in account sparsity. - It just minimizes the distance to templates for detected peaks. - - It is implemented for benchmarking against this low quality template matching. - And also as an example how to deal with methods_kwargs, margin, intit, func, ... - """ - - default_params = { - "templates": None, - "peak_sign": "neg", - "exclude_sweep_ms": 0.1, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "random_chunk_kwargs": {}, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) - - templates = d["templates"] - if d["noise_levels"] is None: - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"] - - channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] +from .base import BaseTemplateMatching, _base_matching_dtype - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - d["exclude_sweep_size"] = int(d["exclude_sweep_ms"] * recording.get_sampling_frequency() / 1000.0) +class NaiveMatching(BaseTemplateMatching): + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.1, + detect_threshold=5, + noise_levels=None, + radius_um=100.0, + random_chunk_kwargs={}, + ): - return d + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - @classmethod - def get_margin(cls, recording, kwargs): - margin = max(kwargs["nbefore"], kwargs["nafter"]) - return margin + self.templates_array = self.templates.get_dense_templates() - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def main_function(cls, traces, method_kwargs): - peak_sign = method_kwargs["peak_sign"] - abs_threholds = method_kwargs["abs_threholds"] - exclude_sweep_size = method_kwargs["exclude_sweep_size"] - neighbours_mask = method_kwargs["neighbours_mask"] - 
templates_array = method_kwargs["templates"].get_dense_templates() + if noise_levels is None: + noise_levels = get_noise_levels(recording, **random_chunk_kwargs, return_scaled=False) + self.abs_threholds = noise_levels * detect_threshold + self.peak_sign = peak_sign + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) + self.nbefore = self.templates.nbefore + self.nafter = self.templates.nafter + self.margin = max(self.nbefore, self.nafter) - nbefore = method_kwargs["nbefore"] - nafter = method_kwargs["nafter"] + def get_trace_margin(self): + return self.margin - margin = method_kwargs["margin"] + def compute_matching(self, traces, start_frame, end_frame, segment_index): - if margin > 0: - peak_traces = traces[margin:-margin, :] + if self.margin > 0: + peak_traces = traces[self.margin : -self.margin, :] else: peak_traces = traces peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, abs_threholds, exclude_sweep_size, neighbours_mask + peak_traces, self.peak_sign, self.abs_threholds, self.exclude_sweep_size, self.neighbours_mask ) - peak_sample_ind += margin + peak_sample_ind += self.margin - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + spikes["channel_index"] = peak_chan_ind # naively take the closest template for i in range(peak_sample_ind.size): - i0 = peak_sample_ind[i] - nbefore - i1 = peak_sample_ind[i] + nafter + i0 = peak_sample_ind[i] - self.nbefore + i1 = peak_sample_ind[i] + self.nafter waveforms = traces[i0:i1, :] - dist = np.sum(np.sum((templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) + dist = np.sum(np.sum((self.templates_array - waveforms[None, :, :]) ** 2, axis=1), axis=1) cluster_index = np.argmin(dist) spikes["cluster_index"][i] = cluster_index diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e66929e2b1..125baa3bda 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -2,24 +2,13 @@ import numpy as np from spikeinterface.core import ( - get_noise_levels, get_channel_distances, - compute_sparsity, get_template_extremum_channel, ) -from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive -from spikeinterface.core.template import Templates +from spikeinterface.sortingcomponents.peak_detection import DetectPeakLocallyExclusive, DetectPeakMatchedFiltering +from .base import BaseTemplateMatching, _base_matching_dtype -spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] - -from .main import BaseTemplateMatchingEngine try: import numba @@ -30,9 +19,9 @@ HAVE_NUMBA = False -class TridesclousPeeler(BaseTemplateMatchingEngine): +class TridesclousPeeler(BaseTemplateMatching): """ - Template-matching ported from Tridesclous sorter. + Template-matching used by Tridesclous sorter. The idea of this peeler is pretty simple. 1. Find peaks @@ -41,350 +30,667 @@ class TridesclousPeeler(BaseTemplateMatchingEngine): 4. remove it from traces. 5. 
in the residual find peaks again - This method is quite fast but don't give exelent results to resolve - spike collision when templates have high similarity. + Contrary tp circus_peeler or wobble, this template matching is working directly one the waveforms. + There is no SVD decomposition + + """ - default_params = { - "templates": None, - "peak_sign": "neg", - "peak_shift_ms": 0.2, - "detect_threshold": 5, - "noise_levels": None, - "radius_um": 100, - "num_closest": 5, - "sample_shift": 3, - "ms_before": 0.8, - "ms_after": 1.2, - "num_peeler_loop": 2, - "num_template_try": 1, - } - - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed" - - d = cls.default_params.copy() - d.update(kwargs) - - assert isinstance(d["templates"], Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + peak_sign="neg", + exclude_sweep_ms=0.5, + peak_shift_ms=0.2, + detect_threshold=5, + noise_levels=None, + use_fine_detector=True, + # TODO optimize theses radius + detection_radius_um=80.0, + cluster_radius_um=150.0, + amplitude_fitting_radius_um=150.0, + sample_shift=2, + ms_before=0.5, + ms_after=0.8, + max_peeler_loop=2, + amplitude_limits=(0.7, 1.4), + ): + + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - templates = d["templates"] unit_ids = templates.unit_ids - channel_ids = templates.channel_ids + channel_ids = recording.channel_ids - sr = templates.sampling_frequency + num_templates = unit_ids.size - d["nbefore"] = templates.nbefore - d["nafter"] = templates.nafter - templates_array = templates.get_dense_templates() + sr = recording.sampling_frequency - nbefore_short = int(d["ms_before"] * sr / 1000.0) - nafter_short = int(d["ms_before"] * sr / 1000.0) + self.nbefore = templates.nbefore + self.nafter = templates.nafter + + self.peak_sign = peak_sign + + nbefore_short = int(ms_before * sr / 1000.0) + nafter_short = int(ms_after * sr / 1000.0) assert nbefore_short <= templates.nbefore assert nafter_short <= templates.nafter - d["nbefore_short"] = nbefore_short - d["nafter_short"] = nafter_short + self.nbefore_short = nbefore_short + self.nafter_short = nafter_short s0 = templates.nbefore - nbefore_short s1 = -(templates.nafter - nafter_short) if s1 == 0: s1 = None - templates_short = templates_array[:, slice(s0, s1), :].copy() - d["templates_short"] = templates_short - d["peak_shift"] = int(d["peak_shift_ms"] / 1000 * sr) + # TODO check with out copy + self.sparse_templates_array_short = templates.templates_array[:, slice(s0, s1), :].copy() - if d["noise_levels"] is None: - print("TridesclousPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording) + self.peak_shift = int(peak_shift_ms / 1000 * sr) - d["abs_thresholds"] = d["noise_levels"] * d["detect_threshold"] + assert noise_levels is not None, "TridesclousPeeler : noise should be computed outside" - channel_distance = get_channel_distances(recording) - d["neighbours_mask"] = channel_distance < d["radius_um"] + self.abs_thresholds = noise_levels * detect_threshold - sparsity = compute_sparsity( - templates, method="best_channels" - ) # , peak_sign=d["peak_sign"], threshold=d["detect_threshold"]) - template_sparsity_inds = sparsity.unit_id_to_channel_indices - template_sparsity = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - 
for unit_index, unit_id in enumerate(unit_ids): - chan_inds = template_sparsity_inds[unit_id] - template_sparsity[unit_index, chan_inds] = True + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance <= detection_radius_um - d["template_sparsity"] = template_sparsity + if templates.sparsity is not None: + self.sparsity_mask = templates.sparsity.mask + else: + self.sparsity_mask = np.ones((unit_ids.size, channel_ids.size), dtype=bool) - extremum_channel = get_template_extremum_channel(templates, peak_sign=d["peak_sign"], outputs="index") + extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") # as numpy vector - extremum_channel = np.array([extremum_channel[unit_id] for unit_id in unit_ids], dtype="int64") - d["extremum_channel"] = extremum_channel + self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") channel_locations = templates.probe.contact_positions - - # TODO try it with real locaion - unit_locations = channel_locations[extremum_channel] - # ~ print(unit_locations) + unit_locations = channel_locations[self.extremum_channel] # distance between units import scipy - unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - - # seach for closet units and unitary discriminant vector - closest_units = [] - for unit_ind, unit_id in enumerate(unit_ids): - order = np.argsort(unit_distances[unit_ind, :]) - closest_u = np.arange(unit_ids.size)[order].tolist() - closest_u.remove(unit_ind) - closest_u = np.array(closest_u[: d["num_closest"]]) - - # compute unitary discriminent vector - (chans,) = np.nonzero(d["template_sparsity"][unit_ind, :]) - template_sparse = templates_array[unit_ind, :, :][:, chans] - closest_vec = [] - # against N closets - for u in closest_u: - vec = templates_array[u, :, :][:, chans] - template_sparse - vec /= np.sum(vec**2) - closest_vec.append((u, vec)) - # against noise - closest_vec.append((None, -template_sparse / np.sum(template_sparse**2))) - - closest_units.append(closest_vec) - - d["closest_units"] = closest_units - - # distance channel from unit - import scipy - - distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") - near_cluster_mask = distances < d["radius_um"] - # nearby cluster for each channel - possible_clusters_by_channel = [] + distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean") + near_cluster_mask = distances <= cluster_radius_um + self.possible_clusters_by_channel = [] for channel_index in range(distances.shape[0]): (cluster_inds,) = np.nonzero(near_cluster_mask[channel_index, :]) - possible_clusters_by_channel.append(cluster_inds) - - d["possible_clusters_by_channel"] = possible_clusters_by_channel - d["possible_shifts"] = np.arange(-d["sample_shift"], d["sample_shift"] + 1, dtype="int64") - - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - return kwargs - - @classmethod - def get_margin(cls, recording, kwargs): - margin = 2 * (kwargs["nbefore"] + kwargs["nafter"]) - return margin + self.possible_clusters_by_channel.append(cluster_inds) + + # precompute template norms ons sparse channels + self.template_norms = np.zeros(num_templates, dtype="float32") + for i in range(unit_ids.size): + chan_mask = self.sparsity_mask[i, :] + n = np.sum(chan_mask) + template = 
templates.templates_array[i, :, :n] + self.template_norms[i] = np.sum(template**2) + + # + distances = scipy.spatial.distance.cdist(channel_locations, channel_locations, metric="euclidean") + self.near_chan_mask = distances <= amplitude_fitting_radius_um + + self.possible_shifts = np.arange(-sample_shift, sample_shift + 1, dtype="int64") + + self.max_peeler_loop = max_peeler_loop + self.amplitude_limits = amplitude_limits + + self.fast_spike_detector = DetectPeakLocallyExclusive( + recording=recording, + peak_sign=peak_sign, + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + noise_levels=noise_levels, + ) - @classmethod - def main_function(cls, traces, d): - traces = traces.copy() + ##get prototype from best channel of each template + prototype = np.zeros(self.nbefore + self.nafter, dtype="float32") + for i in range(num_templates): + template = templates.templates_array[i, :, :] + chan_ind = np.argmax(np.abs(template[self.nbefore, :])) + if template[self.nbefore, chan_ind] != 0: + prototype += template[:, chan_ind] / np.abs(template[self.nbefore, chan_ind]) + prototype /= np.abs(prototype[self.nbefore]) + + # import matplotlib.pyplot as plt + # fig,ax = plt.subplots() + # ax.plot(prototype) + # plt.show() + + self.use_fine_detector = use_fine_detector + if self.use_fine_detector: + self.fine_spike_detector = DetectPeakMatchedFiltering( + recording=recording, + prototype=prototype, + ms_before=templates.nbefore / sr * 1000.0, + peak_sign="neg", + detect_threshold=detect_threshold, + exclude_sweep_ms=exclude_sweep_ms, + radius_um=detection_radius_um, + weight_method=dict( + z_list_um=np.array([50.0]), + sigma_3d=2.5, + mode="exponential_3d", + ), + noise_levels=None, + ) + + self.detector_margin0 = self.fast_spike_detector.get_trace_margin() + self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 + self.peeler_margin = max(self.nbefore, self.nafter) * 2 + self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) + + def get_trace_margin(self): + return self.margin + + def compute_matching(self, traces, start_frame, end_frame, segment_index): + + # TODO check if this is usefull + residuals = traces.copy() all_spikes = [] level = 0 + spikes_prev_loop = np.zeros(0, dtype=_base_matching_dtype) + use_fine_detector_level = False while True: - spikes = _tdc_find_spikes(traces, d, level=level) - keep = spikes["cluster_index"] >= 0 - - if not np.any(keep): - break - all_spikes.append(spikes[keep]) + # print('level', level) + spikes = self._find_spikes_one_level(residuals, spikes_prev_loop, use_fine_detector_level, level) + if spikes.size > 0: + all_spikes.append(spikes) level += 1 - if level == d["num_peeler_loop"]: - break + # TODO concatenate all spikes for this instead of prev loop + spikes_prev_loop = spikes + + if (spikes.size == 0) or (level == self.max_peeler_loop): + if self.use_fine_detector and not use_fine_detector_level: + # extra loop with fine detector + use_fine_detector_level = True + level = self.max_peeler_loop - 1 + continue + else: + break if len(all_spikes) > 0: all_spikes = np.concatenate(all_spikes) order = np.argsort(all_spikes["sample_index"]) all_spikes = all_spikes[order] else: - all_spikes = np.zeros(0, dtype=spike_dtype) + all_spikes = np.zeros(0, dtype=_base_matching_dtype) return all_spikes + def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, level): -def _tdc_find_spikes(traces, d, level=0): - peak_sign = 
d["peak_sign"] - templates = d["templates"] - templates_short = d["templates_short"] - templates_array = templates.get_dense_templates() + # print(use_fine_detector, level) - margin = d["margin"] - possible_clusters_by_channel = d["possible_clusters_by_channel"] + # TODO change the threhold dynaically depending the level + # peak_traces = traces[self.detector_margin : -self.detector_margin, :] - peak_traces = traces[margin // 2 : -margin // 2, :] - peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( - peak_traces, peak_sign, d["abs_thresholds"], d["peak_shift"], d["neighbours_mask"] - ) - peak_sample_ind += margin // 2 + # peak_sample_ind, peak_chan_ind = DetectPeakLocallyExclusive.detect_peaks( + # peak_traces, self.peak_sign, self.abs_thresholds, self.peak_shift, self.neighbours_mask + # ) + + if use_fine_detector: + peak_detector = self.fine_spike_detector + else: + peak_detector = self.fast_spike_detector + + detector_margin = peak_detector.get_trace_margin() + if self.peeler_margin > detector_margin: + margin_shift = self.peeler_margin - detector_margin + sl = slice(margin_shift, -margin_shift) + else: + sl = slice(None) + margin_shift = 0 + peak_traces = traces[sl, :] + (peaks,) = peak_detector.compute(peak_traces, None, None, 0, self.margin) + peak_sample_ind = peaks["sample_index"] + peak_chan_ind = peaks["channel_index"] + peak_sample_ind += margin_shift + + peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + order = np.argsort(np.abs(peak_amplitude))[::-1] + peak_sample_ind = peak_sample_ind[order] + peak_chan_ind = peak_chan_ind[order] + + spikes = np.zeros(peak_sample_ind.size, dtype=_base_matching_dtype) + spikes["sample_index"] = peak_sample_ind + spikes["channel_index"] = peak_chan_ind + + distances_shift = np.zeros(self.possible_shifts.size) + + delta_sample = max(self.nbefore, self.nafter) # TODO check this maybe add margin + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + + # neighbors in actual and previous level + neighbors_spikes_inds = get_neighbors_spikes( + np.concatenate([spikes["sample_index"], spikes_prev_loop["sample_index"]]), + np.concatenate([spikes["channel_index"], spikes_prev_loop["channel_index"]]), + delta_sample, + self.near_chan_mask, + ) - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] - order = np.argsort(np.abs(peak_amplitude))[::-1] - peak_sample_ind = peak_sample_ind[order] - peak_chan_ind = peak_chan_ind[order] + for i in range(spikes.size): + sample_index = peak_sample_ind[i] - spikes = np.zeros(peak_sample_ind.size, dtype=spike_dtype) - spikes["sample_index"] = peak_sample_ind - spikes["channel_index"] = peak_chan_ind # TODO need to put the channel from template + chan_ind = peak_chan_ind[i] + possible_clusters = self.possible_clusters_by_channel[chan_ind] - possible_shifts = d["possible_shifts"] - distances_shift = np.zeros(possible_shifts.size) + if possible_clusters.size > 0: + cluster_index = get_most_probable_cluster( + traces, + self.sparse_templates_array_short, + possible_clusters, + sample_index, + chan_ind, + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, + ) - for i in range(peak_sample_ind.size): - sample_index = peak_sample_ind[i] + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # 
dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + chan_sparsity_mask = self.sparsity_mask[cluster_index, :] - chan_ind = peak_chan_ind[i] - possible_clusters = possible_clusters_by_channel[chan_ind] + # find best shift + numba_best_shift_sparse( + traces, + self.sparse_templates_array_short[cluster_index, :, :], + sample_index, + self.nbefore_short, + self.possible_shifts, + distances_shift, + chan_sparsity_mask, + ) - if possible_clusters.size > 0: - # ~ s0 = sample_index - d['nbefore'] - # ~ s1 = sample_index + d['nafter'] + ind_shift = np.argmin(distances_shift) + shift = self.possible_shifts[ind_shift] + + # TODO DEBUG shift later + spikes["sample_index"][i] += shift + + spikes["cluster_index"][i] = cluster_index + + # check that the the same cluster is not already detected at same place + # this can happen for small template the substract forvever the traces + outer_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if ind > i and ind >= spikes.size] + is_valid = True + for b in outer_neighbors_inds: + b = b - spikes.size + if (spikes[i]["sample_index"] == spikes_prev_loop[b]["sample_index"]) and ( + spikes[i]["cluster_index"] == spikes_prev_loop[b]["cluster_index"] + ): + is_valid = False + + if is_valid: + # temporary assign a cluster to neighbors if not done yet + inner_neighbors_inds = [ind for ind in neighbors_spikes_inds[i] if (ind > i and ind < spikes.size)] + for b in inner_neighbors_inds: + spikes["cluster_index"][b] = get_most_probable_cluster( + traces, + self.sparse_templates_array_short, + possible_clusters, + spikes["sample_index"][b], + spikes["channel_index"][b], + self.nbefore_short, + self.nafter_short, + self.sparsity_mask, + ) + + amp = fit_one_amplitude_with_neighbors( + spikes[i], + spikes[inner_neighbors_inds], + traces, + self.sparsity_mask, + self.templates.templates_array, + self.template_norms, + self.nbefore, + self.nafter, + ) + + low_lim, up_lim = self.amplitude_limits + if low_lim <= amp <= up_lim: + spikes["amplitude"][i] = amp + wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) # TODO move this before the loop + construct_prediction_sparse( + spikes[i : i + 1], + traces, + self.templates.templates_array, + self.sparsity_mask, + wanted_channel_mask, + self.nbefore, + additive=False, + ) + elif low_lim > amp: + # print("bad amp", amp) + spikes["cluster_index"][i] = -1 + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # print(chan_sparsity_mask) + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") + # plt.show() + else: + # amp > up_lim + # TODO should try other cluster for the fit!! 
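# Amplitude gating summary: a fitted amplitude inside self.amplitude_limits is accepted
# and the scaled template is subtracted from the residual traces (construct_prediction_sparse
# with additive=False); an amplitude below the lower limit invalidates the spike
# (cluster_index is set to -1); an amplitude above the upper limit, handled in this branch,
# keeps the spike but forces its amplitude to 1 so that it is fitted again at the next peeler level.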
+ # spikes["cluster_index"][i] = -1 + + # force amplitude to be one and need a fiting at next level + spikes["amplitude"][i] = 1 + + # print(amp) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # sample_ind = spikes["sample_index"][i] + # wf = traces[sample_ind - self.nbefore : sample_ind + self.nafter][:, chan_sparsity_mask] + # dense_templates_array = self.templates.get_dense_templates() + # template = dense_templates_array[cluster_index, :, :][:, chan_sparsity_mask] + # ax.plot(wf.T.flatten()) + # ax.plot(template.T.flatten()) + # ax.plot(template.T.flatten() * amp) + # ax.set_title(f"amp{amp} use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # chans = np.any(self.sparsity_mask[possible_clusters, :], axis=0) + # wf = traces[sample_index - self.nbefore : sample_index + self.nafter][:, chans] + # ax.plot(wf.T.flatten(), color='k') + # dense_templates_array = self.templates.get_dense_templates() + # for c_ind in possible_clusters: + # template = dense_templates_array[c_ind, :, :][:, chans] + # ax.plot(template.T.flatten()) + # if c_ind == cluster_index: + # ax.plot(template.T.flatten(), color='m', ls='--') + # ax.set_title(f"use_fine_detector{use_fine_detector} level{level}") + # plt.show() + + else: + # not valid because already detected + spikes["cluster_index"][i] = -1 - # ~ wf = traces[s0:s1, :] + else: + # no possible cluster in neighborhood for this channel + spikes["cluster_index"][i] = -1 + + # delta_sample = self.nbefore + self.nafter + # # TODO benchmark this and make this faster + # neighbors_spikes_inds = get_neighbors_spikes(spikes["sample_index"], spikes["channel_index"], delta_sample, self.near_chan_mask) + # for i in range(spikes.size): + # amp = fit_one_amplitude_with_neighbors(spikes[i], spikes[neighbors_spikes_inds[i]], traces, + # self.sparsity_mask, self.templates.templates_array, self.nbefore, self.nafter) + # spikes["amplitude"][i] = amp + + keep = spikes["cluster_index"] >= 0 + spikes = spikes[keep] + + # keep = (spikes["amplitude"] >= 0.7) & (spikes["amplitude"] <= 1.4) + # spikes = spikes[keep] + + # sparse_templates_array = self.templates.templates_array + # wanted_channel_mask = np.ones(traces.shape[1], dtype=bool) + # assert np.sum(wanted_channel_mask) == traces.shape[1] # TODO remove this DEBUG later + # construct_prediction_sparse(spikes, traces, sparse_templates_array, self.sparsity_mask, wanted_channel_mask, self.nbefore, additive=False) + + return spikes + + +def get_most_probable_cluster( + traces, + sparse_templates_array, + possible_clusters, + sample_index, + chan_ind, + nbefore_short, + nafter_short, + template_sparsity_mask, +): + s0 = sample_index - nbefore_short + s1 = sample_index + nafter_short + wf_short = traces[s0:s1, :] + + ## numba with cluster+channel spasity + union_channels = np.any(template_sparsity_mask[possible_clusters, :], axis=0) + distances = numba_sparse_distance( + wf_short, sparse_templates_array, template_sparsity_mask, union_channels, possible_clusters + ) - s0 = sample_index - d["nbefore_short"] - s1 = sample_index + d["nafter_short"] - wf_short = traces[s0:s1, :] + ind = np.argmin(distances) + cluster_index = possible_clusters[ind] - ## pure numpy with cluster spasity - # distances = np.sum(np.sum((templates[possible_clusters, :, :] - wf[None, : , :])**2, axis=1), axis=1) + return cluster_index - ## pure numpy with cluster+channel spasity - # union_channels, = np.nonzero(np.any(d['template_sparsity'][possible_clusters, 
:], axis=0)) - # distances = np.sum(np.sum((templates[possible_clusters][:, :, union_channels] - wf[: , union_channels][None, : :])**2, axis=1), axis=1) - ## numba with cluster+channel spasity - union_channels = np.any(d["template_sparsity"][possible_clusters, :], axis=0) - # distances = numba_sparse_dist(wf, templates, union_channels, possible_clusters) - distances = numba_sparse_dist(wf_short, templates_short, union_channels, possible_clusters) +def get_neighbors_spikes(sample_inds, chan_inds, delta_sample, near_chan_mask): - # DEBUG - # ~ ind = np.argmin(distances) - # ~ cluster_index = possible_clusters[ind] + neighbors_spikes_inds = [] + for i in range(sample_inds.size): - for ind in np.argsort(distances)[: d["num_template_try"]]: - cluster_index = possible_clusters[ind] + inds = np.flatnonzero(np.abs(sample_inds - sample_inds[i]) < delta_sample) + neighb = [] + for ind in inds: + if near_chan_mask[chan_inds[i], chan_inds[ind]] and i != ind: + neighb.append(ind) + neighbors_spikes_inds.append(neighb) - chan_sparsity = d["template_sparsity"][cluster_index, :] - template_sparse = templates_array[cluster_index, :, :][:, chan_sparsity] + return neighbors_spikes_inds - # find best shift - ## pure numpy version - # for s, shift in enumerate(possible_shifts): - # wf_shift = traces[s0 + shift: s1 + shift, chan_sparsity] - # distances_shift[s] = np.sum((template_sparse - wf_shift)**2) - # ind_shift = np.argmin(distances_shift) - # shift = possible_shifts[ind_shift] +def fit_one_amplitude_with_neighbors( + spike, neighbors_spikes, traces, template_sparsity_mask, sparse_templates_array, template_norms, nbefore, nafter +): + """ + Fit amplitude one spike of one spike with/without neighbors - ## numba version - numba_best_shift( - traces, - templates_array[cluster_index, :, :], - sample_index, - d["nbefore"], - possible_shifts, - distances_shift, - chan_sparsity, - ) - ind_shift = np.argmin(distances_shift) - shift = possible_shifts[ind_shift] - - sample_index = sample_index + shift - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] - wf_sparse = traces[s0:s1, chan_sparsity] - - # accept or not - - centered = wf_sparse - template_sparse - accepted = True - for other_ind, other_vector in d["closest_units"][cluster_index]: - v = np.sum(centered * other_vector) - if np.abs(v) > 0.5: - accepted = False - break - - if accepted: - # ~ if ind != np.argsort(distances)[0]: - # ~ print('not first one', np.argsort(distances), ind) - break + """ - if accepted: - amplitude = 1.0 + import scipy.linalg + + cluster_index = spike["cluster_index"] + sample_index = spike["sample_index"] + chan_sparsity_mask = template_sparsity_mask[cluster_index, :] + num_chans = np.sum(chan_sparsity_mask) + if num_chans == 0: + # protect against empty template because too sparse + return 0.0 + start, stop = sample_index - nbefore, sample_index + nafter + if neighbors_spikes is None or (neighbors_spikes.size == 0): + template = sparse_templates_array[cluster_index, :, :num_chans] + wf = traces[start:stop, :][:, chan_sparsity_mask] + # TODO precompute template norms + amplitude = np.sum(template.flatten() * wf.flatten()) / template_norms[cluster_index] + else: + + lim0 = min(start, np.min(neighbors_spikes["sample_index"]) - nbefore) + lim1 = max(stop, np.max(neighbors_spikes["sample_index"]) + nafter) + + local_traces = traces[lim0:lim1, :][:, chan_sparsity_mask] + mask_not_fitted = (neighbors_spikes["amplitude"] == 0.0) & (neighbors_spikes["cluster_index"] >= 0) + local_spike = spike.copy() + 
local_spike["sample_index"] -= lim0 + local_spike["amplitude"] = 1.0 + + local_neighbors_spikes = neighbors_spikes.copy() + local_neighbors_spikes["sample_index"] -= lim0 + local_neighbors_spikes["amplitude"][:] = 1.0 + + num_spikes_to_fit = 1 + np.sum(mask_not_fitted) + x = np.zeros((lim1 - lim0, num_chans, num_spikes_to_fit), dtype="float32") + wanted_channel_mask = chan_sparsity_mask + construct_prediction_sparse( + np.array([local_spike]), + x[:, :, 0], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) - # remove template - template = templates_array[cluster_index, :, :] - s0 = sample_index - d["nbefore"] - s1 = sample_index + d["nafter"] - traces[s0:s1, :] -= template * amplitude + j = 1 + for i in range(neighbors_spikes.size): + if mask_not_fitted[i]: + # add to one regressor + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + x[:, :, j], + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + True, + ) + j += 1 + elif local_neighbors_spikes[neighbors_spikes[i]]["sample_index"] >= 0: + # remove from traces + construct_prediction_sparse( + local_neighbors_spikes[i : i + 1], + local_traces, + sparse_templates_array, + template_sparsity_mask, + chan_sparsity_mask, + nbefore, + False, + ) + # else: + # pass - else: - cluster_index = -1 - amplitude = 0.0 + x = x.reshape(-1, num_spikes_to_fit) + y = local_traces.flatten() - else: - cluster_index = -1 - amplitude = 0.0 + res = scipy.linalg.lstsq(x, y, cond=None, lapack_driver="gelsd") + amplitudes = res[0] + amplitude = amplitudes[0] - spikes["cluster_index"][i] = cluster_index - spikes["amplitude"][i] = amplitude + # import matplotlib.pyplot as plt + # x_plot = x.reshape((lim1 - lim0, num_chans, num_spikes_to_fit)).swapaxes(0, 1).reshape(-1, num_spikes_to_fit) + # pred = x @ amplitudes + # pred_plot = pred.reshape(-1, num_chans).T.flatten() + # y_plot = y.reshape(-1, num_chans).T.flatten() + # fig, ax = plt.subplots() + # ax.plot(x_plot, color='b') + # print(x_plot.shape, y_plot.shape) + # ax.plot(y_plot, color='g') + # ax.plot(pred_plot , color='r') + # ax.set_title(f"{amplitudes}") + # # ax.set_title(f"{amplitudes} {amp_dot}") + # plt.show() - return spikes + return amplitude if HAVE_NUMBA: @jit(nopython=True) - def numba_sparse_dist(wf, templates, union_channels, possible_clusters): + def construct_prediction_sparse( + spikes, traces, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, nbefore, additive + ): + # must have np.sum(wanted_channel_mask) == traces.shape[0] + total_chans = wanted_channel_mask.shape[0] + for spike in spikes: + ind0 = spike["sample_index"] - nbefore + ind1 = ind0 + sparse_templates_array.shape[1] + cluster_index = spike["cluster_index"] + amplitude = spike["amplitude"] + chan_in_template = 0 + chan_in_trace = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + if additive: + traces[ind0:ind1, chan_in_trace] += ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) + else: + traces[ind0:ind1, chan_in_trace] -= ( + sparse_templates_array[cluster_index, :, chan_in_template] * amplitude + ) + chan_in_template += 1 + chan_in_trace += 1 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 + + @jit(nopython=True) + def numba_sparse_distance( + wf, sparse_templates_array, template_sparsity_mask, wanted_channel_mask, possible_clusters + ): """ numba implementation that compute distance from 
template with sparsity - handle by two separate vectors + + wf is dense + sparse_templates_array is sparse with the template_sparsity_mask """ - total_cluster, width, num_chan = templates.shape + width, total_chans = wf.shape num_cluster = possible_clusters.shape[0] distances = np.zeros((num_cluster,), dtype=np.float32) for i in prange(num_cluster): cluster_index = possible_clusters[i] sum_dist = 0.0 - for chan_ind in range(num_chan): - if union_channels[chan_ind]: - for s in range(width): - v = wf[s, chan_ind] - t = templates[cluster_index, s, chan_ind] - sum_dist += (v - t) ** 2 + chan_in_template = 0 + for chan in range(total_chans): + if wanted_channel_mask[chan]: + if template_sparsity_mask[cluster_index, chan]: + for s in range(width): + v = wf[s, chan] + t = sparse_templates_array[cluster_index, s, chan_in_template] + sum_dist += (v - t) ** 2 + chan_in_template += 1 + else: + for s in range(width): + v = wf[s, chan] + t = 0 + sum_dist += (v - t) ** 2 + else: + if template_sparsity_mask[cluster_index, chan]: + chan_in_template += 1 distances[i] = sum_dist return distances @jit(nopython=True) - def numba_best_shift(traces, template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity): + def numba_best_shift_sparse( + traces, sparse_template, sample_index, nbefore, possible_shifts, distances_shift, chan_sparsity + ): """ numba implementation to compute several sample shift before template substraction """ - width, num_chan = template.shape + width = sparse_template.shape[0] + total_chans = traces.shape[1] n_shift = possible_shifts.size for i in range(n_shift): shift = possible_shifts[i] sum_dist = 0.0 - for chan_ind in range(num_chan): - if chan_sparsity[chan_ind]: + chan_in_template = 0 + for chan in range(total_chans): + if chan_sparsity[chan]: for s in range(width): - v = traces[sample_index - nbefore + s + shift, chan_ind] - t = template[s, chan_ind] + v = traces[sample_index - nbefore + s + shift, chan] + t = sparse_template[s, chan_in_template] sum_dist += (v - t) ** 2 + chan_in_template += 1 distances_shift[i] = sum_dist return distances_shift diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 99de6fcd4e..59e171fe52 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -4,9 +4,19 @@ from dataclasses import dataclass from typing import List, Tuple, Optional -from .main import BaseTemplateMatchingEngine + +from .base import BaseTemplateMatching, _base_matching_dtype from spikeinterface.core.template import Templates +try: + import torch + import torch.nn.functional as F + + HAVE_TORCH = True + from torch.nn.functional import conv1d +except ImportError: + HAVE_TORCH = False + @dataclass class WobbleParameters: @@ -40,6 +50,10 @@ class WobbleParameters: Maximum value for ampltiude scaling of templates. scale_amplitudes : bool If True, scale amplitudes of templates to match spikes. + engine : string in ["numpy", "torch", "auto"]. Default "auto" + The engine to use for the convolutions + torch_device : string in ["cpu", "cuda", None]. 
Default "cpu" + Controls torch device if the torch engine is selected Notes ----- @@ -61,6 +75,8 @@ class WobbleParameters: scale_min: float = 0 scale_max: float = np.inf scale_amplitudes: bool = False + engine: str = "numpy" + torch_device: str = "cpu" def __post_init__(self): assert self.amplitude_variance >= 0, "amplitude_variance must be a non-negative scalar" @@ -197,8 +213,9 @@ def from_parameters_and_templates(cls, params, templates): return template_meta +# important : this is differents from the spikeinterface.core.Sparsity @dataclass -class Sparsity: +class WobbleSparsity: """Variables that describe channel sparsity. Parameters @@ -226,7 +243,7 @@ def from_parameters_and_templates(cls, params, templates): Returns ------- - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. """ visible_channels = np.ptp(templates, axis=1) > params.visibility_threshold @@ -250,7 +267,7 @@ def from_templates(cls, params, templates): Returns ------- - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. """ visible_channels = templates.sparsity.mask @@ -297,7 +314,7 @@ def __post_init__(self): self.temporal, self.singular, self.spatial, self.temporal_jittered = self.compressed_templates -class WobbleMatch(BaseTemplateMatchingEngine): +class WobbleMatch(BaseTemplateMatching): """Template matching method from the Paninski lab. Templates are jittered or "wobbled" in time and amplitude to capture variability in spike amplitude and @@ -331,53 +348,47 @@ class WobbleMatch(BaseTemplateMatchingEngine): - "peaks" are considered spikes if their amplitude clears the threshold parameter """ - default_params = { - "templates": None, - } - spike_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("cluster_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), - ] + # default_params = { + # "templates": None, + # } - @classmethod - def initialize_and_check_kwargs(cls, recording, kwargs): - """Initialize the objective and precompute various useful objects. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object. - kwargs : dict - Keyword arguments for matching method. - - Returns - ------- - d : dict - Updated Keyword arguments. 
- """ - d = cls.default_params.copy() + def __init__( + self, + recording, + return_output=True, + parents=None, + templates=None, + parameters={}, + engine="numpy", + torch_device="cpu", + ): - required_kwargs_keys = ["templates"] - for required_key in required_kwargs_keys: - assert required_key in kwargs, f"`{required_key}` is a required key in the kwargs" + BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) - parameters = kwargs.get("parameters", {}) - templates = kwargs["templates"] - assert isinstance(templates, Templates), ( - f"The templates supplied is of type {type(d['templates'])} " f"and must be a Templates" - ) - templates_array = templates.get_dense_templates().astype(np.float32, casting="safe") + templates_array = templates.get_dense_templates().astype(np.float32) # Aggregate useful parameters/variables for handy access in downstream functions params = WobbleParameters(**parameters) + + assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto" + if engine == "auto": + if HAVE_TORCH: + self.engine = "torch" + else: + self.engine = "numpy" + else: + if engine == "torch": + assert HAVE_TORCH, "please install torch to use the torch engine" + self.engine = engine + + assert torch_device in ["cuda", "cpu", None] + self.torch_device = torch_device + template_meta = TemplateMetadata.from_parameters_and_templates(params, templates_array) if not templates.are_templates_sparse(): - sparsity = Sparsity.from_parameters_and_templates(params, templates_array) + sparsity = WobbleSparsity.from_parameters_and_templates(params, templates_array) else: - sparsity = Sparsity.from_templates(params, templates) + sparsity = WobbleSparsity.from_templates(params, templates) # Perform initial computations on templates necessary for computing the objective sparse_templates = np.where(sparsity.visible_channels[:, np.newaxis, :], templates_array, 0) @@ -387,91 +398,77 @@ def initialize_and_check_kwargs(cls, recording, kwargs): pairwise_convolution = convolve_templates( compressed_templates, params.jitter_factor, params.approx_rank, template_meta.jittered_indices, sparsity ) + norm_squared = compute_template_norm(sparsity.visible_channels, templates_array) + + spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) + temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) + singular = singular.T[:, :, np.newaxis] + + compressed_templates = (temporal, singular, spatial, temporal_jittered) template_data = TemplateData( compressed_templates=compressed_templates, pairwise_convolution=pairwise_convolution, norm_squared=norm_squared, ) - # Pack initial data into kwargs - kwargs["params"] = params - kwargs["template_meta"] = template_meta - kwargs["sparsity"] = sparsity - kwargs["template_data"] = template_data - kwargs["nbefore"] = templates.nbefore - kwargs["nafter"] = templates.nafter - d.update(kwargs) - return d - - @classmethod - def serialize_method_kwargs(cls, kwargs): - # This function does nothing without a waveform extractor -- candidate for refactor - kwargs = dict(kwargs) - return kwargs - - @classmethod - def unserialize_in_worker(cls, kwargs): - # This function does nothing without a waveform extractor -- candidate for refactor - return kwargs + self.is_pushed = False + self.params = params + self.template_meta = template_meta + self.sparsity = sparsity + self.template_data = template_data + self.nbefore = templates.nbefore + self.nafter = templates.nafter - @classmethod - def get_margin(cls, recording, kwargs): - """Get margin for 
chunking recording. + # buffer_ms = 10 + # self.margin = int(buffer_ms*1e-3 * recording.sampling_frequency) + self.margin = 300 # To ensure equivalence with spike-psvae version of the algorithm - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object. - kwargs : dict - Keyword arguments for matching method. + def _push_to_torch(self): + if self.engine == "torch": + temporal, singular, spatial, temporal_jittered = self.template_data.compressed_templates + spatial = torch.as_tensor(spatial, device=self.torch_device) + singular = torch.as_tensor(singular, device=self.torch_device) + temporal = torch.as_tensor(temporal.copy(), device=self.torch_device).swapaxes(0, 1) + temporal = torch.flip(temporal, (2,)) + self.template_data.compressed_templates = (temporal, singular, spatial, temporal_jittered) + self.is_pushed = True - Returns - ------- - margin : int - Buffer in samples on each side of a chunk. - """ - buffer_ms = 10 - # margin = int(buffer_ms*1e-3 * recording.sampling_frequency) - margin = 300 # To ensure equivalence with spike-psvae version of the algorithm - return margin + def get_trace_margin(self): + return self.margin - @classmethod - def main_function(cls, traces, method_kwargs): - """Detect spikes in traces using the template matching algorithm. + def compute_matching(self, traces, start_frame, end_frame, segment_index): - Parameters - ---------- - traces : ndarray (chunk_len + 2*margin, num_channels) - Voltage traces for a chunk of the recording. - method_kwargs : dict - Keyword arguments for matching method. + if not self.is_pushed: + self._push_to_torch() - Returns - ------- - spikes : ndarray (num_spikes,) - Resulting spike train. - """ # Unpack method_kwargs - nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] - template_meta = method_kwargs["template_meta"] - params = method_kwargs["params"] - sparsity = method_kwargs["sparsity"] - template_data = method_kwargs["template_data"] + # nbefore, nafter = method_kwargs["nbefore"], method_kwargs["nafter"] + # template_meta = method_kwargs["template_meta"] + # params = method_kwargs["params"] + # sparsity = method_kwargs["sparsity"] + # template_data = method_kwargs["template_data"] # Check traces assert traces.dtype == np.float32, "traces must be specified as np.float32" # Compute objective - objective = compute_objective(traces, template_data, params.approx_rank) - objective_normalized = 2 * objective - template_data.norm_squared[:, np.newaxis] + objective = compute_objective( + traces, self.template_data, self.params.approx_rank, self.engine, self.torch_device + ) + objective_normalized = 2 * objective - self.template_data.norm_squared[:, np.newaxis] # Compute spike train spike_trains, scalings, distance_metrics = [], [], [] - for i in range(params.max_iter): + for i in range(self.params.max_iter): # find peaks - spike_train, scaling, distance_metric = cls.find_peaks( - objective, objective_normalized, np.array(spike_trains), params, template_data, template_meta + spike_train, scaling, distance_metric = self.find_peaks( + objective, + objective_normalized, + np.array(spike_trains), + self.params, + self.template_data, + self.template_meta, ) if len(spike_train) == 0: break @@ -482,15 +479,22 @@ def main_function(cls, traces, method_kwargs): distance_metrics.extend(list(distance_metric)) # subtract newly detected spike train from traces (via the objective) - objective, objective_normalized = cls.subtract_spike_train( - spike_train, scaling, template_data, objective, 
objective_normalized, params, template_meta, sparsity + objective, objective_normalized = self.subtract_spike_train( + spike_train, + scaling, + self.template_data, + objective, + objective_normalized, + self.params, + self.template_meta, + self.sparsity, ) spike_train = np.array(spike_trains) scalings = np.array(scalings) distance_metric = np.array(distance_metrics) if len(spike_train) == 0: # no spikes found - return np.zeros(0, dtype=cls.spike_dtype) + return np.zeros(0, dtype=_base_matching_dtype) # order spike times index = np.argsort(spike_train[:, 0]) @@ -499,8 +503,8 @@ def main_function(cls, traces, method_kwargs): distance_metric = distance_metric[index] # adjust spike_train - spike_train[:, 0] += nbefore # beginning of template --> center of template - spike_train[:, 1] //= params.jitter_factor # jittered_index --> template_index + spike_train[:, 0] += self.nbefore # beginning of template --> center of template + spike_train[:, 1] //= self.params.jitter_factor # jittered_index --> template_index # TODO : Benchmark spike amplitudes # Find spike amplitudes / channels @@ -512,7 +516,7 @@ def main_function(cls, traces, method_kwargs): channel_inds.append(best_ch) # assign result to spikes array - spikes = np.zeros(spike_train.shape[0], dtype=cls.spike_dtype) + spikes = np.zeros(spike_train.shape[0], dtype=_base_matching_dtype) spikes["sample_index"] = spike_train[:, 0] spikes["cluster_index"] = spike_train[:, 1] spikes["channel_index"] = channel_inds @@ -622,7 +626,7 @@ def subtract_spike_train( Dataclass object for aggregating the parameters together. template_meta : TemplateMetadata Dataclass object for aggregating template metadata together. - sparsity : Sparsity + sparsity : WobbleSparsity Dataclass object for aggregating channel sparsity variables together. Returns @@ -837,10 +841,11 @@ def compress_templates(templates, approx_rank) -> tuple[np.ndarray, np.ndarray, temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) # Keep only the strongest components - temporal = temporal[:, :, :approx_rank] + temporal = temporal[:, :, :approx_rank].astype(np.float32) temporal = np.flip(temporal, axis=1) - singular = singular[:, :approx_rank] - spatial = spatial[:, :approx_rank, :] + singular = singular[:, :approx_rank].astype(np.float32) + spatial = spatial[:, :approx_rank, :].astype(np.float32) + return temporal, singular, spatial @@ -878,7 +883,6 @@ def upsample_and_jitter(temporal, jitter_factor, num_samples): shape_temporal_jittered = (-1, num_samples, approx_rank) temporal_jittered = np.reshape(temporal_jittered[:, shifted_index, :], shape_temporal_jittered) - temporal_jittered = np.flip(temporal_jittered, axis=1) return temporal_jittered @@ -940,7 +944,7 @@ def convolve_templates(compressed_templates, jitter_factor, approx_rank, jittere return pairwise_convolution -def compute_objective(traces, template_data, approx_rank) -> np.ndarray: +def compute_objective(traces, template_data, approx_rank, engine="numpy", torch_device=None) -> np.ndarray: """Compute objective by convolving templates with voltage traces. Parameters @@ -949,31 +953,39 @@ def compute_objective(traces, template_data, approx_rank) -> np.ndarray: Voltage traces for a chunk of the recording. template_data : TemplateData Dataclass object for aggregating template data together. - approx_rank : int - Rank of the compressed template matrices. Returns ------- objective : ndarray (template_meta.num_templates, traces.shape[0]+template_meta.num_samples-1) Template matching objective for each template. 
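# (Editor sketch, not from this diff) The numpy branch of compute_objective below is a
# low-rank convolution: project the traces onto the spatial SVD components, scale by the
# singular values, convolve with the temporal components and sum over ranks.
# All shapes here are made up for illustration only.
import numpy as np
from scipy import signal
rng = np.random.default_rng(0)
rank, n_templates, n_channels, n_samples, chunk = 3, 4, 7, 20, 1000
spatial = rng.standard_normal((rank, n_templates, n_channels)).astype("float32")
singular = rng.random((rank, n_templates, 1)).astype("float32")
temporal = rng.standard_normal((rank, n_templates, n_samples)).astype("float32")
traces = rng.standard_normal((chunk, n_channels)).astype("float32")
spatially_filtered = np.matmul(spatial, traces.T[np.newaxis, :, :])  # (rank, n_templates, chunk)
scaled = spatially_filtered * singular
objective = signal.oaconvolve(scaled, temporal, axes=2, mode="full").sum(axis=0)
assert objective.shape == (n_templates, chunk + n_samples - 1)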
""" - temporal, singular, spatial, temporal_jittered = template_data.compressed_templates - num_templates = temporal.shape[0] - num_samples = temporal.shape[1] - objective_len = get_convolution_len(traces.shape[0], num_samples) - conv_shape = (num_templates, objective_len) - objective = np.zeros(conv_shape, dtype=np.float32) - spatial_filters = np.moveaxis(spatial[:, :approx_rank, :], [0, 1, 2], [1, 0, 2]) - temporal_filters = np.moveaxis(temporal[:, :, :approx_rank], [0, 1, 2], [1, 2, 0]) - singular_filters = singular.T[:, :, np.newaxis] - - # Filter using overlap-and-add convolution - spatially_filtered_data = np.matmul(spatial_filters, traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * singular_filters - from scipy import signal + temporal, singular, spatial, _ = template_data.compressed_templates + if engine == "torch": + nt = temporal.shape[2] - 1 + num_channels = traces.shape[1] + blank = np.zeros((nt, num_channels), dtype=np.float32) + traces = np.vstack((blank, traces, blank)) + torch_traces = torch.as_tensor(traces.T[None, :, :], device=torch_device) + num_templates, num_channels = temporal.shape[0], temporal.shape[1] + num_timesteps = torch_traces.shape[2] + spatially_filtered_data = torch.matmul(spatial, torch_traces) + scaled_filtered_data = (spatially_filtered_data * singular).swapaxes(0, 1) + scaled_filtered_data_ = scaled_filtered_data.reshape(1, num_templates * num_channels, num_timesteps) + objective = conv1d(scaled_filtered_data_, temporal, groups=num_templates, padding="valid") + objective = objective.cpu().numpy()[0, :, :] + elif engine == "numpy": + num_channels, num_templates = temporal.shape[0], temporal.shape[1] + num_timesteps = temporal.shape[2] + objective_len = get_convolution_len(traces.shape[0], num_timesteps) + conv_shape = (num_templates, objective_len) + objective = np.zeros(conv_shape, dtype=np.float32) + # Filter using overlap-and-add convolution + spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * singular + from scipy import signal - objective_by_rank = signal.oaconvolve(scaled_filtered_data, temporal_filters, axes=2, mode="full") - objective += np.sum(objective_by_rank, axis=0) + objective_by_rank = signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="full") + objective += np.sum(objective_by_rank, axis=0) return objective diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index e2b6b1a2bc..bfedd4e1ee 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -22,20 +22,19 @@ """ +import gc import warnings -from tqdm.auto import trange import numpy as np - -import gc +from tqdm.auto import trange from .motion_utils import ( Motion, + get_spatial_bin_edges, get_spatial_windows, get_window_domains, - scipy_conv1d, make_2d_motion_histogram, - get_spatial_bin_edges, + scipy_conv1d, ) @@ -979,7 +978,7 @@ def xcorr_windows( if max_disp_um is None: if rigid: - max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + max_disp_um = int(np.ptp(spatial_bin_edges_um) // 4) else: max_disp_um = int(win_scale_um // 4) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a5e6ded519..fc8ccb788b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ 
b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -6,6 +6,8 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype +from .motion_utils import ensure_time_bin_edges, ensure_time_bins + def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ @@ -54,6 +56,7 @@ def interpolate_motion_on_traces( segment_index=None, channel_inds=None, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, dtype=None, @@ -61,7 +64,11 @@ def interpolate_motion_on_traces( """ Apply inverse motion with spatial interpolation on traces. - Traces can be full traces, but also waveforms snippets. + Traces can be full traces, but also waveforms snippets. Times used for looking up + displacements are controlled by interpolation_time_bin_edges_s or + interpolation_time_bin_centers_s, or fall back to the Motion object's time bins + by default; times in the recording outside these time bins use the closest edge + bin's displacement value during interpolation. Parameters ---------- @@ -80,6 +87,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. + interpolation_time_bin_edges_s : None or np.array + If present, interpolation chunks will be the time bins defined by these edges + rather than interpolation_time_bin_centers_s or the motion's bins. spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing @@ -119,26 +129,33 @@ def interpolate_motion_on_traces( total_num_chans = channel_locations.shape[0] # -- determine the blocks of frames that will land in the same interpolation time bin - time_bins = interpolation_time_bin_centers_s - if time_bins is None: - time_bins = motion.temporal_bins_s[segment_index] - bin_s = time_bins[1] - time_bins[0] - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? - bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index] + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) + + # bin the frame times according to the interpolation time bins. + # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # hence the -1. doing it with "left" is not as nice -- we want t==b[0] + # to lead to i=1 (rounding down). + interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 + # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + n_bins = interpolation_time_bin_edges_s.shape[0] - 1 + np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) # -- what are the possibilities here anyway? 
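# (Editor sketch, values are illustrative only) The searchsorted/clip pattern above maps
# each frame time to an interpolation bin; out-of-range times reuse the closest edge bin,
# as described in the docstring.
import numpy as np
edges = np.array([0.0, 1.0, 2.0, 3.0])  # 3 bins
times = np.array([-0.5, 0.0, 0.4, 1.0, 2.9, 3.5])
bin_inds = np.searchsorted(edges, times, side="right") - 1
np.clip(bin_inds, 0, edges.size - 2, out=bin_inds)
# bin_inds is now [0, 0, 0, 1, 2, 2]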
- bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) current_start_index = 0 - for bin_ind in bins_here: - bin_time = time_bins[bin_ind] + for interp_bin_ind in interpolation_bins_here: + bin_time = interpolation_time_bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -166,16 +183,17 @@ def interpolate_motion_on_traces( # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") # plt.show() + # quick search logic to find frames corresponding to this interpolation bin in the recording # quickly find the end of this bin, which is also the start of the next next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" + interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" ) - in_bin = slice(current_start_index, next_start_index) + frames_in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index return traces_corrected @@ -297,6 +315,7 @@ def __init__( p=1, num_closest=3, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, interpolation_time_bin_size_s=None, dtype=None, **spatial_interpolation_kwargs, @@ -363,9 +382,14 @@ def __init__( # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below - if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: if interpolation_time_bin_size_s is None: interpolation_time_bin_centers_s = motion.temporal_bins_s + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) for segment_index, parent_segment in enumerate(recording._recording_segments): # finish the per-segment part of the time bin logic @@ -375,8 +399,13 @@ def __init__( t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + segment_interpolation_time_bin_edges_s = np.arange( + t_start, t_end + halfbin, interpolation_time_bin_size_s + ) + assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] rec_segment = InterpolateMotionRecordingSegment( parent_segment, @@ -387,6 +416,7 @@ def __init__( channel_inds, segment_index, segment_interpolation_time_bins_s, + segment_interpolation_time_bin_edges_s, dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -420,6 +450,7 @@ def __init__( 
channel_inds, segment_index, interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s, dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -429,13 +460,11 @@ def __init__( self.channel_inds = channel_inds self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s self.dtype = dtype self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: - raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") - if start_frame is None: start_frame = 0 if end_frame is None: @@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, - interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, ) if channel_indices is not None: diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 635624cca8..680d75f221 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -1,5 +1,5 @@ -import warnings import json +import warnings from pathlib import Path import numpy as np @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction self.dim = ["x", "y", "z"].index(direction) self.check_properties() + self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) @@ -576,3 +577,40 @@ def make_3d_motion_histograms( motion_histograms = np.log2(1 + motion_histograms) return motion_histograms, temporal_bin_edges, spatial_bin_edges + + +def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): + """Ensure that both bin edges and bin centers are present + + If either of the inputs are None but not both, the missing is reconstructed + from the present. Going from edges to centers is done by taking midpoints. + Going from centers to edges is done by taking midpoints and padding with the + left and rightmost centers. 
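# (Editor sketch with hypothetical values, not part of the original docstring) a quick
# numeric check of the centers -> edges reconstruction described above:
import numpy as np
centers = np.array([0.5, 1.5, 2.5, 3.5])
edges = np.empty(centers.size + 1)
edges[[0, -1]] = centers[[0, -1]]                 # pad with the outermost centers
edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])  # interior edges are midpoints
# edges -> [0.5, 1.0, 2.0, 3.0, 3.5]; the two boundary bins end up half-width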
+ + Parameters + ---------- + time_bin_centers_s : None or np.array + time_bin_edges_s : None or np.array + + Returns + ------- + time_bin_centers_s, time_bin_edges_s + """ + if time_bin_centers_s is None and time_bin_edges_s is None: + raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") + + if time_bin_centers_s is None: + assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 + time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + + if time_bin_edges_s is None: + time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) + time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] + if time_bin_centers_s.size > 2: + time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) + + return time_bin_centers_s, time_bin_edges_s + + +def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): + return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index e022f0cc6c..e4ba870325 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -1,16 +1,14 @@ -from pathlib import Path +import warnings import numpy as np -import pytest import spikeinterface.core as sc -from spikeinterface import download_dataset +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -67,43 +65,45 @@ def test_interpolate_motion_on_traces(): times = rec.get_times()[0:30000] for method in ("kriging", "idw", "nearest"): - traces_corrected = interpolate_motion_on_traces( - traces, - times, - channel_locations, - motion, - channel_inds=None, - spatial_interpolation_method=method, - # spatial_interpolation_kwargs={}, - spatial_interpolation_kwargs={"force_extrapolate": True}, - ) - assert traces.shape == traces_corrected.shape - assert traces.dtype == traces_corrected.dtype + for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): + traces_corrected = interpolate_motion_on_traces( + traces, + times, + channel_locations, + motion, + channel_inds=None, + spatial_interpolation_method=method, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, + ) + assert traces.shape == traces_corrected.shape + assert traces.dtype == traces_corrected.dtype def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. 
# there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest @@ -115,6 +115,66 @@ def test_interpolation_simple(): assert np.all(traces_corrected[:, 2:] == 0) +def test_cross_band_interpolation(): + """Simple version of using LFP to interpolate AP data + + This also tests the time vector implementation in interpolation. + The idea is to have two recordings which are all 0s with a 1 that + moves from one channel to another after 3s. They're at different + sampling frequencies. motion estimation in one sampling frequency + applied to the other should still lead to perfect correction. + """ + from spikeinterface.sortingcomponents.motion import estimate_motion + + # sampling freqs and timing for AP and LFP recordings + fs_lfp = 50.0 + fs_ap = 300.0 + t_start = 10.0 + total_duration = 5.0 + num_samples_lfp = int(fs_lfp * total_duration) + num_samples_ap = int(fs_ap * total_duration) + t_switch = 3 + + # because interpolation uses bin centers logic, there will be a half + # bin offset at the change point in the AP recording. 
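# (Editor note, illustrative arithmetic only) with the sampling rates defined above,
# fs_ap / fs_lfp = 300 / 50 = 6 AP samples per LFP sample, so half an LFP bin is
assert int(0.5 * (300.0 / 50.0)) == 3  # 3 AP samples of offset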
+ halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) + + # channel geometry + num_chans = 10 + geom = np.c_[np.zeros(num_chans), np.arange(num_chans)] + + # make an LFP recording which drifts a bit + traces_lfp = np.zeros((num_samples_lfp, num_chans)) + traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 + traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 + rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) + rec_lfp.set_dummy_probe_from_locations(geom) + + # same for AP + traces_ap = np.zeros((num_samples_ap, num_chans)) + traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 + traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 + rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) + rec_ap.set_dummy_probe_from_locations(geom) + + # set times for both, and silence the warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap) + + # estimate motion + motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) + + # nearest to keep it simple + rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) + traces_corrected = rec_corrected.get_traces() + target = np.zeros((num_samples_ap, num_chans - 2)) + target[:, 4] = 1 + ii, jj = np.nonzero(traces_corrected) + assert np.array_equal(traces_corrected, target) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() motion = make_fake_motion(rec) @@ -147,6 +207,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() - # test_interpolate_motion_on_traces() - test_interpolation_simple() - test_InterpolateMotionRecording() + test_interpolate_motion_on_traces() + # test_interpolation_simple() + # test_InterpolateMotionRecording() + test_cross_band_interpolation() diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 4fe90dd7bc..12955e2c40 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -50,7 +50,15 @@ def detect_peaks( - recording, method="locally_exclusive", pipeline_nodes=None, gather_mode="memory", folder=None, names=None, **kwargs + recording, + method="locally_exclusive", + pipeline_nodes=None, + gather_mode="memory", + folder=None, + names=None, + skip_after_n_peaks=None, + recording_slices=None, + **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -73,6 +81,13 @@ def detect_peaks( If gather_mode is "npy", the folder where the files are created. names : list List of strings with file stems associated with returns. + skip_after_n_peaks : None | int + Skip the computation after n_peaks. + This is not exact, because internally the skip is applied per worker, so it only holds on average. + recording_slices : None | list[tuple] + Optionally give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording.
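# (Editor sketch) hedged usage of the two new detect_peaks arguments documented above;
# generate_recording and the "by_channel" detection method come from the existing
# spikeinterface API, and all numeric values are illustrative only.
from spikeinterface.core import generate_recording
rec = generate_recording(num_channels=4, sampling_frequency=30000.0, durations=[2.0])
recording_slices = [(0, 0, 30000), (0, 30000, 60000)]  # (segment_index, frame_start, frame_stop)
peaks = detect_peaks(
    rec,
    method="by_channel",
    skip_after_n_peaks=1000,            # stop gathering once roughly 1000 peaks have been found
    recording_slices=recording_slices,  # only scan these two chunks of segment 0
)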
{method_doc} {job_doc} @@ -103,7 +118,11 @@ def detect_peaks( squeeze_output = True else: squeeze_output = False - job_name += f" + {len(pipeline_nodes)} nodes" + if len(pipeline_nodes) == 1: + plural = "" + else: + plural = "s" + job_name += f" + {len(pipeline_nodes)} node{plural}" # because node are modified inplace (insert parent) they need to copy incase # the same pipeline is run several times @@ -124,6 +143,8 @@ def detect_peaks( squeeze_output=squeeze_output, folder=folder, names=names, + skip_after_n_peaks=skip_after_n_peaks, + recording_slices=recording_slices, ) return outs @@ -592,13 +613,13 @@ class DetectPeakMatchedFiltering(PeakDetector): params_doc = ( DetectPeakByChannel.params_doc + """ - radius_um: float + radius_um : float The radius to use to select neighbour channels for locally exclusive detection. - prototype: array + prototype : array The canonical waveform of action potentials - rank : int (default 1) - The rank for SVD convolution of spatiotemporal templates with the traces - weight_method: dict + ms_before : float + The time in ms before the maximial value of the absolute prototype + weight_method : dict Parameter that should be provided to the get_convolution_weights() function in order to know how to estimate the positions. One argument is mode that could be either gaussian_2d (KS like) or exponential_3d (default) @@ -614,12 +635,12 @@ def __init__( detect_threshold=5, exclude_sweep_ms=0.1, radius_um=50, - rank=1, noise_levels=None, random_chunk_kwargs={"num_chunks_per_segment": 5}, weight_method={}, ): PeakDetector.__init__(self, recording, return_output=True) + from scipy.sparse import csr_matrix if not HAVE_NUMBA: raise ModuleNotFoundError('matched_filtering" needs numba which is not installed') @@ -631,52 +652,35 @@ def __init__( self.conv_margin = prototype.shape[0] assert peak_sign in ("both", "neg", "pos") - idx = np.argmax(np.abs(prototype)) + self.nbefore = int(ms_before * recording.sampling_frequency / 1000) if peak_sign == "neg": - assert prototype[idx] < 0, "Prototype should have a negative peak" + assert prototype[self.nbefore] < 0, "Prototype should have a negative peak" peak_sign = "pos" elif peak_sign == "pos": - assert prototype[idx] > 0, "Prototype should have a positive peak" - elif peak_sign == "both": - raise NotImplementedError("Matched filtering not working with peak_sign=both yet!") + assert prototype[self.nbefore] > 0, "Prototype should have a positive peak" self.peak_sign = peak_sign - self.nbefore = int(ms_before * recording.sampling_frequency / 1000) + self.prototype = np.flip(prototype) / np.linalg.norm(prototype) + contact_locations = recording.get_channel_locations() dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2) - weights, self.z_factors = get_convolution_weights(dist, **weight_method) - - num_channels = recording.get_num_channels() - num_templates = num_channels * len(self.z_factors) - weights = weights.reshape(num_templates, -1) - - templates = weights[:, None, :] * prototype[None, :, None] - templates -= templates.mean(axis=(1, 2))[:, None, None] - temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False) - temporal = temporal[:, :, :rank] - singular = singular[:, :rank] - spatial = spatial[:, :rank, :] - templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial) - norms = np.linalg.norm(templates, axis=(1, 2)) - del templates - - temporal /= norms[:, np.newaxis, np.newaxis] - temporal = np.flip(temporal, axis=1) - spatial = np.moveaxis(spatial, 
[0, 1, 2], [1, 0, 2]) - temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) - singular = singular.T[:, :, np.newaxis] - - self.temporal = temporal - self.spatial = spatial - self.singular = singular - + self.weights, self.z_factors = get_convolution_weights(dist, **weight_method) + self.num_z_factors = len(self.z_factors) + self.num_channels = recording.get_num_channels() + self.num_templates = self.num_channels + if peak_sign == "both": + self.weights = np.hstack((self.weights, self.weights)) + self.weights[:, self.num_templates :, :] *= -1 + self.num_templates *= 2 + + self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1) + self.weights = csr_matrix(self.weights) random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs) - conv_random_data = self.get_convolved_traces(random_data, temporal, spatial, singular) + conv_random_data = self.get_convolved_traces(random_data) medians = np.median(conv_random_data, axis=1) medians = medians[:, None] noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817 self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("z", "float32")]) def get_dtype(self): @@ -688,16 +692,13 @@ def get_trace_margin(self): def compute(self, traces, start_frame, end_frame, segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" - conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular) + conv_traces = self.get_convolved_traces(traces) conv_traces /= self.abs_thresholds[:, None] conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size] - num_z_factors = len(self.z_factors) - num_templates = traces.shape[1] - - traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1]) - conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1]) + traces_center = traces_center.reshape(self.num_z_factors, self.num_templates, traces_center.shape[1]) + conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1]) peak_mask = traces_center > 1 peak_mask = _numba_detect_peak_matched_filtering( @@ -708,11 +709,13 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): self.abs_thresholds, self.peak_sign, self.neighbours_mask, - num_templates, + self.num_channels, ) # Find peaks and correct for time shift z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) + if self.peak_sign == "both": + peak_chan_ind = peak_chan_ind % self.num_channels # If we want to estimate z # peak_chan_ind = peak_chan_ind % num_channels @@ -727,8 +730,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (np.zeros(0, dtype=self._dtype),) peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) local_peaks["sample_index"] = peak_sample_ind local_peaks["channel_index"] = peak_chan_ind @@ -739,16 +742,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # return is always a tuple return (local_peaks,) - def get_convolved_traces(self, traces, temporal, spatial, singular): - import scipy.signal + def get_convolved_traces(self, traces): + from scipy.signal import oaconvolve - num_timesteps, num_templates = len(traces), 
temporal.shape[1] - num_peaks = num_timesteps - self.conv_margin + 1 - scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32) - spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * singular - objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid") - scalar_products += np.sum(objective_by_rank, axis=0) + tmp = oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid") + scalar_products = self.weights.dot(tmp) return scalar_products @@ -873,37 +871,28 @@ def _numba_detect_peak_neg( @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_matched_filtering( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates + traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels ): num_z = traces_center.shape[0] + num_templates = traces_center.shape[1] for template_ind in range(num_templates): for z in range(num_z): for s in range(peak_mask.shape[2]): if not peak_mask[z, template_ind, s]: continue for neighbour in range(num_templates): - if not neighbours_mask[template_ind, neighbour]: - continue for j in range(num_z): + if not neighbours_mask[template_ind % num_channels, neighbour % num_channels]: + continue for i in range(exclude_sweep_size): - if template_ind >= neighbour: - if z >= j: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] - ) - else: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) - elif template_ind < neighbour: - if z > j: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) - else: - peak_mask[z, template_ind, s] &= ( - traces_center[z, template_ind, s] > traces_center[j, neighbour, s] - ) + if template_ind >= neighbour and z >= j: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] >= traces_center[j, neighbour, s] + ) + else: + peak_mask[z, template_ind, s] &= ( + traces_center[z, template_ind, s] > traces_center[j, neighbour, s] + ) peak_mask[z, template_ind, s] &= ( traces_center[z, template_ind, s] > traces[j, neighbour, s + i] ) diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index ddc8add995..1e4e0edded 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -33,7 +33,7 @@ get_grid_convolution_templates_and_weights, ) -from .tools import get_prototype_spike +from .tools import get_prototype_and_waveforms_from_peaks def get_localization_pipeline_nodes( @@ -73,8 +73,8 @@ def get_localization_pipeline_nodes( assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) # extract prototypes silently job_kwargs["progress_bar"] = False - method_kwargs["prototype"] = get_prototype_spike( - recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + method_kwargs["prototype"], _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False @@ -135,7 +135,7 @@ def __init__(self, recording, return_output=True, parents=None, 
radius_um=75.0): self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) - self.neighbours_mask = self.channel_distance < radius_um + self.neighbours_mask = self.channel_distance <= radius_um self._kwargs["radius_um"] = radius_um def get_dtype(self): diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index 01e4445a13..d5e5b6be1b 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -21,4 +21,10 @@ def make_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), seed=2205, ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + return recording, sorting diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index fa30ba3483..341ed3426d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -22,7 +22,7 @@ ) from spikeinterface.core.node_pipeline import run_node_pipeline -from spikeinterface.sortingcomponents.tools import get_prototype_spike +from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_peaks from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_spike(recording, peaks_by_channel_np, ms_before, ms_after, **job_kwargs) + prototype, _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) peaks_local_mf_filtering = detect_peaks( recording, @@ -328,19 +330,38 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ) assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np) + peaks_local_mf_filtering_both = detect_peaks( + recording, + method="matched_filtering", + peak_sign="both", + detect_threshold=5, + exclude_sweep_ms=0.1, + prototype=prototype, + ms_before=1.0, + **job_kwargs, + ) + assert len(peaks_local_mf_filtering_both) > len(peaks_local_mf_filtering) + DEBUG = False if DEBUG: import matplotlib.pyplot as plt - peaks = peaks_local_mf_filtering + peaks_local = peaks_by_channel_np + peaks_mf_neg = peaks_local_mf_filtering + peaks_mf_both = peaks_local_mf_filtering_both + labels = ["locally_exclusive", "mf_neg", "mf_both"] - sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] + fig, ax = plt.subplots() chan_offset = 500 traces = recording.get_traces().copy() traces += np.arange(traces.shape[1])[None, :] * chan_offset - fig, ax = plt.subplots() ax.plot(traces, color="k") - ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r") + + for count, peaks in enumerate([peaks_local, peaks_mf_neg, peaks_mf_both]): + sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] + ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count]) + + ax.legend() plt.show() diff 
--git a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py index dab19809be..7cd899a3bb 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_template_matching.py +++ b/src/spikeinterface/sortingcomponents/tests/test_template_matching.py @@ -9,7 +9,8 @@ from spikeinterface.sortingcomponents.tests.common import make_dataset -job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +# job_kwargs = dict(n_jobs=-1, chunk_duration="500ms", progress_bar=True) +job_kwargs = dict(n_jobs=1, chunk_duration="500ms", progress_bar=True) def get_sorting_analyzer(): @@ -40,40 +41,48 @@ def test_find_spikes_from_templates(method, sorting_analyzer): noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() # sorting_analyzer - method_kwargs_all = {"templates": templates, "noise_levels": noise_levels} + method_kwargs_all = { + "templates": templates, + } method_kwargs = {} + if method in ("naive", "tdc-peeler", "circus", "tdc-peeler2"): + method_kwargs["noise_levels"] = noise_levels + # method_kwargs["wobble"] = { # "templates": waveform_extractor.get_all_templates(), # "nbefore": waveform_extractor.nbefore, # "nafter": waveform_extractor.nafter, # } - sampling_frequency = recording.get_sampling_frequency() + method_kwargs.update(method_kwargs_all) + spikes, info = find_spikes_from_templates( + recording, method=method, method_kwargs=method_kwargs, extra_outputs=True, **job_kwargs + ) - method_kwargs_ = method_kwargs.get(method, {}) - method_kwargs_.update(method_kwargs_all) - spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs_, **job_kwargs) + # print(info) - # DEBUG = True + DEBUG = True - # if DEBUG: - # import matplotlib.pyplot as plt - # import spikeinterface.full as si + if DEBUG: + import matplotlib.pyplot as plt + import spikeinterface.full as si - # sorting_analyzer.compute("waveforms") - # sorting_analyzer.compute("templates") + sorting_analyzer.compute("waveforms") + sorting_analyzer.compute("templates") - # gt_sorting = sorting_analyzer.sorting + gt_sorting = sorting_analyzer.sorting - # sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["cluster_index"], sampling_frequency) + sorting = NumpySorting.from_times_labels( + spikes["sample_index"], spikes["cluster_index"], recording.sampling_frequency + ) - # metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) + ##metrics = si.compute_quality_metrics(sorting_analyzer, metric_names=["snr"]) - # fig, ax = plt.subplots() - # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) - # si.plot_agreement_matrix(comp, ax=ax) - # ax.set_title(method) - # plt.show() + # fig, ax = plt.subplots() + # comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting) + # si.plot_agreement_matrix(comp, ax=ax) + # ax.set_title(method) + # plt.show() if __name__ == "__main__": diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py new file mode 100644 index 0000000000..1b006af429 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_hanning_filter.py @@ -0,0 +1,33 @@ +import pytest + + +from spikeinterface.sortingcomponents.waveforms.hanning_filter import HanningFilter + +from spikeinterface.core.node_pipeline import ( + PeakRetriever, + ExtractDenseWaveforms, + run_node_pipeline, +) + + +def 
test_hanning_filter(generated_recording, detected_peaks, chunk_executor_kwargs): + recording = generated_recording + peaks = detected_peaks + + # Parameters + ms_before = 1.0 + ms_after = 1.0 + + # Node initialization + peak_retriever = PeakRetriever(recording, peaks) + + extract_waveforms = ExtractDenseWaveforms( + recording=recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=True + ) + + hanning_filter = HanningFilter(recording=recording, parents=[peak_retriever, extract_waveforms]) + pipeline_nodes = [peak_retriever, extract_waveforms, hanning_filter] + + # Extract projected waveforms and compare + waveforms, denoised_waveforms = run_node_pipeline(recording, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs) + assert waveforms.shape == denoised_waveforms.shape diff --git a/src/spikeinterface/sortingcomponents/tests/test_wobble.py b/src/spikeinterface/sortingcomponents/tests/test_wobble.py index 5e6be02409..0d46b790ad 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_wobble.py +++ b/src/spikeinterface/sortingcomponents/tests/test_wobble.py @@ -44,7 +44,7 @@ def test_compress_templates(): elif test_case == "num_channels == num_samples": num_channels = rng.integers(1, 100) num_samples = num_channels - templates = rng.random((num_templates, num_samples, num_channels)) + templates = rng.random((num_templates, num_samples, num_channels), dtype=np.float32) full_rank = np.minimum(num_samples, num_channels) approx_rank = rng.integers(1, full_rank) @@ -66,15 +66,31 @@ def test_compress_templates(): assert np.all(singular_full >= 0) # check that svd matrices are orthonormal if applicable if num_channels > num_samples: - assert np.allclose(np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), np.eye(num_samples)) + assert np.allclose( + np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), + np.eye(num_samples, dtype=np.float32), + atol=1e-3, + ) elif num_samples > num_channels: - assert np.allclose(np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), np.eye(num_channels)) + assert np.allclose( + np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), + np.eye(num_channels, dtype=np.float32), + atol=1e-3, + ) elif num_channels == num_samples: - assert np.allclose(np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), np.eye(num_samples)) - assert np.allclose(np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), np.eye(num_channels)) + assert np.allclose( + np.matmul(temporal_full, temporal_full.transpose(0, 2, 1)), + np.eye(num_samples, dtype=np.float32), + atol=1e-3, + ) + assert np.allclose( + np.matmul(spatial_full, spatial_full.transpose(0, 2, 1)), + np.eye(num_channels, dtype=np.float32), + atol=1e-3, + ) # check that the full rank svd matrices reconstruct the original templates reconstructed_templates = np.matmul(temporal_full * singular_full[:, np.newaxis, :], spatial_full) - assert np.allclose(reconstructed_templates, templates) + assert np.allclose(reconstructed_templates, templates, atol=1e-3) def test_upsample_and_jitter(): @@ -143,7 +159,7 @@ def test_convolve_templates(): ) unit_overlap = unit_overlap > 0 unit_overlap = np.repeat(unit_overlap, jitter_factor, axis=0) - sparsity = wobble.Sparsity(visible_channels, unit_overlap) + sparsity = wobble.WobbleSparsity(visible_channels, unit_overlap) # Act: run convolve_templates pairwise_convolution = wobble.convolve_templates( @@ -211,18 +227,33 @@ def test_compute_objective(): approx_rank = rng.integers(1, num_samples) num_channels = rng.integers(1, 100) 
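# (Editor sketch, not part of the original test) shape bookkeeping for the axis reordering
# applied to the compressed templates a few lines below; uses the sizes drawn above.
_t = np.empty((num_templates, num_samples, approx_rank))
_s = np.empty((num_templates, approx_rank))
_sp = np.empty((num_templates, approx_rank, num_channels))
assert np.moveaxis(_t, [0, 1, 2], [1, 2, 0]).shape == (approx_rank, num_templates, num_samples)
assert _s.T[:, :, np.newaxis].shape == (approx_rank, num_templates, 1)
assert np.moveaxis(_sp, [0, 1, 2], [1, 0, 2]).shape == (approx_rank, num_templates, num_channels)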
chunk_len = rng.integers(num_samples * 2, num_samples * 10) - traces = rng.random((chunk_len, num_channels)) + traces = rng.random((chunk_len, num_channels), dtype=np.float32) temporal = rng.random((num_templates, num_samples, approx_rank)) singular = rng.random((num_templates, approx_rank)) spatial = rng.random((num_templates, approx_rank, num_channels)) - compressed_templates = (temporal, singular, spatial, temporal) + + spatial_transformed = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) + temporal_transformed = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0]) + singular_transformed = singular.T[:, :, np.newaxis] + + compressed_templates_transformed = ( + temporal_transformed, + singular_transformed, + spatial_transformed, + temporal_transformed, + ) norm_squared = np.random.rand(num_templates) + + template_data_transformed = wobble.TemplateData( + compressed_templates=compressed_templates_transformed, pairwise_convolution=[], norm_squared=norm_squared + ) + # Act: run compute_objective + objective = wobble.compute_objective(traces, template_data_transformed, approx_rank, engine="numpy") + + compressed_templates = (temporal, singular, spatial, temporal) template_data = wobble.TemplateData( compressed_templates=compressed_templates, pairwise_convolution=[], norm_squared=norm_squared ) - - # Act: run compute_objective - objective = wobble.compute_objective(traces, template_data, approx_rank) expected_objective = compute_objective_loopy(traces, template_data, approx_rank) # Assert: check shape and equivalence to expected_objective diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1501582336..1bd2381cda 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -69,25 +69,174 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs): +def get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. 
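    Examples
    --------
    Editorial sketch, not part of the original diff. It assumes ``recording`` is a
    ``BaseRecording`` and ``peaks`` comes from an earlier ``detect_peaks()`` call;
    extra keyword arguments are parsed for job kwargs (e.g. ``n_jobs``):

    >>> prototype, waveforms, selected_peaks = get_prototype_and_waveforms_from_peaks(
    ...     recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=42, n_jobs=1
    ... )
    >>> prototype.shape    # (nbefore + nafter,), median waveform normalized at the peak sample
    >>> waveforms.shape    # (n_selected_peaks, nbefore + nafter), max-channel waveforms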
+ """ from spikeinterface.sortingcomponents.peak_selection import select_peaks + _, job_kwargs = split_job_kwargs(all_kwargs) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter)) - + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) + + with np.errstate(divide="ignore", invalid="ignore"): + prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms_from_recording( + recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks detected on the fly. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + node = ExtractSparseWaveforms( + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) + + pipeline_nodes = [node] + + recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) + + res = detect_peaks( + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, + ) + + rng = np.random.RandomState(seed) + indices = rng.permutation(np.arange(len(res[0]))) + + few_peaks = res[0][indices[:n_peaks]] + waveforms = res[1][indices[:n_peaks]] + with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform either from peaks or from a peak detection. Note that in case + of a peak detection, the detection stops as soon as n_peaks are detected. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. 
+ peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + if peaks is None: + return get_prototype_and_waveforms_from_recording( + recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) + else: + return get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels() - if num_channels < 32: + if num_channels <= 32: return False else: locations = recording.get_channel_locations() @@ -151,3 +300,20 @@ def fit_sigmoid(xdata, ydata, p0=None): popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) return popt + + +def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): + from spikeinterface.core.job_tools import ensure_chunk_size + from spikeinterface.core.job_tools import divide_segment_into_chunks + + chunk_size = ensure_chunk_size(recording, **job_kwargs) + recording_slices = [] + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + chunks = divide_segment_into_chunks(num_frames, chunk_size) + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + + rng = np.random.default_rng(seed) + recording_slices = rng.permutation(recording_slices) + + return recording_slices diff --git a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py new file mode 100644 index 0000000000..e5d4962997 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py @@ -0,0 +1,50 @@ +from __future__ import annotations + + +from typing import List, Optional +import numpy as np +from spikeinterface.core import BaseRecording +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type + + +class HanningFilter(WaveformsNode): + """ + Hanning Filtering to remove border effects while extracting waveforms + + Parameters + ---------- + recording: BaseRecording + The recording extractor object + return_output: bool, default: True + Whether to return output from this node + parents: list of PipelineNodes, default: None + The parent nodes of this node + """ + + def __init__( + self, + recording: BaseRecording, + return_output: bool = True, + parents: Optional[List[PipelineNode]] = None, + ): + waveform_extractor = find_parent_of_type(parents, WaveformsNode) + if waveform_extractor is None: + raise TypeError(f"HanningFilter should have a single {WaveformsNode.__name__} in its parents") + + super().__init__( + recording, + waveform_extractor.ms_before, + waveform_extractor.ms_after, + return_output=return_output, + parents=parents, + ) + + hanning_before = np.hanning(2 * self.nbefore) + hanning_after = np.hanning(2 * 
self.nafter) + hanning = np.concatenate((hanning_before[: self.nbefore], hanning_after[self.nafter :])) + self.hanning = hanning[:, None] + self._kwargs.update(dict()) + + def compute(self, traces, peaks, waveforms): + denoised_waveforms = waveforms * self.hanning + return denoised_waveforms diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index c8acd93dc2..c211a277f8 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): - CrossCorrelogramsWidget.__init__(self, *args, **kargs) + _ = kargs.pop("min_similarity_for_correlograms", 0.0) + CrossCorrelogramsWidget.__init__( + self, + *args, + **kargs, + min_similarity_for_correlograms=None, + ) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index cdb2041aa3..88dd803323 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget): List of unit ids min_similarity_for_correlograms : float, default: 0.2 For sortingview backend. Threshold for computing pair-wise cross-correlograms. - If template similarity between two units is below this threshold, the cross-correlogram is not displayed + If template similarity between two units is below this threshold, the cross-correlogram is not displayed. + For auto-correlograms plot, this is automatically set to None. window_ms : float, default: 100.0 Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 85043d0d12..5e160a6a5a 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -1,127 +1,67 @@ +""" +This module will be deprecated and will be removed in 0.102.0 + +All ploting for the previous GTStudy is now centralized in spikeinterface.benchmark.benchmark_plot_tools +Please not that GTStudy is replaced by SorterStudy wich is based more generic BenchmarkStudy. +""" + from __future__ import annotations -import numpy as np +from .base import BaseWidget -from .base import BaseWidget, to_attr +import warnings class StudyRunTimesWidget(BaseWidget): """ - Plot sorter run times for a GroundTruthStudy - + Plot sorter run times for a SorterStudy. Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, run_times=study.get_run_times(case_keys), case_keys=case_keys, colors=study.get_colors() + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn( + "plot_study_run_times is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." 
) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - - dp = to_attr(data_plot) + from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + plot_run_times(data_plot["study"], case_keys=data_plot["case_keys"]) - for i, key in enumerate(dp.case_keys): - label = dp.study.cases[key]["label"] - rt = dp.run_times.loc[key] - self.ax.bar(i, rt, width=0.8, label=label, facecolor=dp.colors[key]) - self.ax.set_ylabel("run time (s)") - self.ax.legend() - -# TODO : plot optionally average on some levels using group by class StudyUnitCountsWidget(BaseWidget): """ Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" - Parameters ---------- - study : GroundTruthStudy + study : SorterStudy A study object. case_keys : list or None A selection of cases to plot, if None, then all. """ - def __init__( - self, - study, - case_keys=None, - backend=None, - **backend_kwargs, - ): - if case_keys is None: - case_keys = list(study.cases.keys()) - - plot_data = dict( - study=study, - count_units=study.get_count_units(case_keys=case_keys), - case_keys=case_keys, + def __init__(self, study, case_keys=None, backend=None, **backend_kwargs): + warnings.warn( + "plot_study_unit_counts is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors + from spikeinterface.benchmark.benchmark_plot_tools import plot_unit_counts - dp = to_attr(data_plot) - - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - columns = dp.count_units.columns.tolist() - columns.remove("num_gt") - columns.remove("num_sorter") - - ncol = len(columns) - - colors = get_some_colors(columns, color_engine="auto", map_name="hot") - colors["num_well_detected"] = "green" - - xticklabels = [] - for i, key in enumerate(dp.case_keys): - for c, col in enumerate(columns): - x = i + 1 + c / (ncol + 1) - y = dp.count_units.loc[key, col] - if not "well_detected" in col: - y = -y - - if i == 0: - label = col.replace("num_", "").replace("_", " ").title() - else: - label = None - - self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) - - xticklabels.append(dp.study.cases[key]["label"]) - - self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) - self.ax.set_xticklabels(xticklabels) - self.ax.legend() + plot_unit_counts(data_plot["study"], case_keys=data_plot["case_keys"]) class StudyPerformances(BaseWidget): @@ -154,78 +94,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_performances is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." 
+ ) plot_data = dict( study=study, - perfs=study.get_performance_by_unit(case_keys=case_keys), mode=mode, performance_names=performance_names, case_keys=case_keys, ) - - self.colors = study.get_colors() - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .utils import get_some_colors - - import pandas as pd - import seaborn as sns - - dp = to_attr(data_plot) - perfs = dp.perfs - study = dp.study - - if dp.mode in ("ordered", "snr"): - backend_kwargs["num_axes"] = len(dp.performance_names) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - if dp.mode == "ordered": - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - for key in dp.case_keys: - label = study.cases[key]["label"] - val = perfs.xs(key).loc[:, performance_name].values - val = np.sort(val)[::-1] - ax.plot(val, label=label, c=self.colors[key]) - ax.set_title(performance_name) - if count == len(dp.performance_names) - 1: - ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) - - elif dp.mode == "snr": - metric_name = dp.mode - for count, performance_name in enumerate(dp.performance_names): - ax = self.axes.flatten()[count] - - max_metric = 0 - for key in dp.case_keys: - x = study.get_metrics(key).loc[:, metric_name].values - y = perfs.xs(key).loc[:, performance_name].values - label = study.cases[key]["label"] - ax.scatter(x, y, s=10, label=label, color=self.colors[key]) - max_metric = max(max_metric, np.max(x)) - ax.set_title(performance_name) - ax.set_xlim(0, max_metric * 1.05) - ax.set_ylim(0, 1.05) - if count == 0: - ax.legend(loc="lower right") - - elif dp.mode == "swarm": - levels = perfs.index.names - df = pd.melt( - perfs.reset_index(), - id_vars=levels, - var_name="Metric", - value_name="Score", - value_vars=dp.performance_names, - ) - df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) - sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) + from spikeinterface.benchmark.benchmark_plot_tools import plot_performances + + plot_performances( + data_plot["study"], + mode=data_plot["mode"], + performance_names=data_plot["performance_names"], + case_keys=data_plot["case_keys"], + ) class StudyAgreementMatrix(BaseWidget): @@ -251,9 +139,9 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - + warnings.warn( + "plot_study_agreement_matrix is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." 
+ ) plot_data = dict( study=study, case_keys=case_keys, @@ -263,36 +151,9 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - from .comparison import AgreementMatrixWidget - - dp = to_attr(data_plot) - study = dp.study - - backend_kwargs["num_axes"] = len(dp.case_keys) - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - - for count, key in enumerate(dp.case_keys): - ax = self.axes.flatten()[count] - comp = study.comparisons[key] - unit_ticks = len(comp.sorting1.unit_ids) <= 16 - count_text = len(comp.sorting1.unit_ids) <= 16 - - AgreementMatrixWidget( - comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax - ) - label = study.cases[key]["label"] - ax.set_xlabel(label) - - if count > 0: - ax.set_ylabel(None) - ax.set_yticks([]) - ax.set_xticks([]) + from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix - # ax0 = self.axes.flatten()[0] - # for ax in self.axes.flatten()[1:]: - # ax.sharey(ax0) + plot_agreement_matrix(data_plot["study"], ordered=data_plot["ordered"], case_keys=data_plot["case_keys"]) class StudySummary(BaseWidget): @@ -320,25 +181,26 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: - case_keys = list(study.cases.keys()) - plot_data = dict( - study=study, - case_keys=case_keys, + warnings.warn( + "plot_study_summary is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead." ) - + plot_data = dict(study=study, case_keys=case_keys) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from .utils_matplotlib import make_mpl_figure - study = data_plot["study"] case_keys = data_plot["case_keys"] - StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs) - StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs) - StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) - StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs) + from spikeinterface.benchmark.benchmark_plot_tools import ( + plot_agreement_matrix, + plot_performances, + plot_unit_counts, + plot_run_times, + ) + + plot_performances(study=study, case_keys=case_keys, mode="ordered") + plot_performances(study=study, case_keys=case_keys, mode="snr") + plot_agreement_matrix(study=study, case_keys=case_keys) + plot_run_times(study=study, case_keys=case_keys) + plot_unit_counts(study=study, case_keys=case_keys) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index a113298851..8eada29b0e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -2,6 +2,8 @@ import numpy as np +import warnings + from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget @@ -14,6 +16,9 @@ from ..core import SortingAnalyzer +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] + + class SortingSummaryWidget(BaseWidget): """ Plots spike sorting summary. 
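Editorial sketch, not part of the diff: how the renamed argument is meant to be called once the hunks below replace unit_table_properties with displayed_unit_properties. It assumes `analyzer` is a SortingAnalyzer with the correlograms, spike_amplitudes, unit_locations and template_similarity extensions already computed; the property names are a subset of `_default_displayed_unit_properties` above.

import spikeinterface.widgets as sw

# New keyword: displayed_unit_properties. The old unit_table_properties keyword is kept
# only as a deprecated alias that raises a DeprecationWarning and is remapped onto it.
sw.plot_sorting_summary(
    analyzer,
    displayed_unit_properties=["firing_rate", "snr"],
    backend="sortingview",
)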
@@ -42,14 +47,24 @@ class SortingSummaryWidget(BaseWidget): label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) - unit_table_properties : list or None, default: None + displayed_unit_properties : list or None, default: None List of properties to be added to the unit table. These may be drawn from the sorting extractor, and, if available, - the quality_metrics and template_metrics extensions of the SortingAnalyzer. + the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer. See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - (sortingview backend) + extra_unit_properties : dict or None, default: None + A dict with extra units properties to display. + curation_dict : dict or None, default: None + When curation is True, optionaly the viewer can get a previous 'curation_dict' + to continue/check previous curations on this analyzer. + In this case label_definitions must be None beacuse it is already included in the curation_dict. + (spikeinterface_gui backend) + label_definitions : dict or None, default: None + When curation is True, optionaly the user can provide a label_definitions dict. + This replaces the label_choices in the curation_format. + (spikeinterface_gui backend) """ def __init__( @@ -60,11 +75,24 @@ def __init__( max_amplitudes_per_unit=None, min_similarity_for_correlograms=0.2, curation=False, - unit_table_properties=None, + displayed_unit_properties=None, + extra_unit_properties=None, label_choices=None, + curation_dict=None, + label_definitions=None, backend=None, + unit_table_properties=None, **backend_kwargs, ): + + if unit_table_properties is not None: + warnings.warn( + "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", + category=DeprecationWarning, + stacklevel=2, + ) + displayed_unit_properties = unit_table_properties + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions( sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] @@ -74,18 +102,29 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - plot_data = dict( + if curation_dict is not None and label_definitions is not None: + raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") + + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, - unit_table_properties=unit_table_properties, + displayed_unit_properties=displayed_unit_properties, + extra_unit_properties=extra_unit_properties, curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, + curation_dict=curation_dict, + label_definitions=label_definitions, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv @@ -156,7 +195,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): 
# unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores ) if dp.curation: @@ -190,9 +229,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui + from spikeinterface_gui import run_mainwindow - app = spikeinterface_gui.mkQApp() - win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"]) - win.show() - app.exec_() + run_mainwindow( + sorting_analyzer, + with_traces=True, + curation=data_plot["curation"], + curation_dict=data_plot["curation_dict"], + label_definitions=data_plot["label_definitions"], + extra_unit_properties=data_plot["extra_unit_properties"], + displayed_unit_properties=data_plot["displayed_unit_properties"], + ) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 80f58f5ad9..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,9 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), + quality_metrics=dict( + metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"] + ), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -531,18 +533,29 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary( + self.sorting_analyzer_dense, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) sw.plot_sorting_summary( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, + displayed_unit_properties=[], backend=backend, **self.backend_kwargs[backend], ) - # add unit_properties + # select unit_properties sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["firing_rate", "snr"], + displayed_unit_properties=["firing_rate", "snr"], backend=backend, **self.backend_kwargs[backend], ) @@ -550,7 +563,7 @@ def test_plot_sorting_summary(self): with self.assertWarns(UserWarning): sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["missing_property"], + displayed_unit_properties=["missing_property"], backend=backend, **self.backend_kwargs[backend], ) @@ -688,9 +701,9 @@ def test_plot_motion_info(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - # mytest.test_plot_sorting_summary() + mytest.test_plot_sorting_summary() # mytest.test_plot_motion() - mytest.test_plot_motion_info() - plt.show() + # mytest.test_plot_motion_info() + # plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py 
index 86f2350a85..f944a4a80e 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -52,6 +52,8 @@ class TracesWidget(BaseWidget): If dict, keys should be the same as recording keys scale : float, default: 1 Scale factor for the traces + vspacing_factor : float, default: 1.5 + Vertical spacing between channels as a multiple of maximum channel amplitude with_colorbar : bool, default: True When mode is "map", a colorbar is added tile_size : int, default: 1500 @@ -82,6 +84,7 @@ def __init__( tile_size=1500, seconds_per_row=0.2, scale=1, + vspacing_factor=1.5, with_colorbar=True, add_legend=True, backend=None, @@ -142,7 +145,7 @@ def __init__( fs = rec0.get_sampling_frequency() if time_range is None: time_range = (t_start, t_start + 1.0) - time_range = np.array(time_range) + time_range = np.array(time_range, dtype=np.float64) if time_range[1] > t_end: warnings.warn( "You have selected a time after the end of the segment. The range will be clipped to " f"{t_end}" @@ -168,7 +171,7 @@ def __init__( traces0 = list_traces[0] mean_channel_std = np.mean(np.std(traces0, axis=0)) max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 + vspacing = max_channel_amp * vspacing_factor if rec0.get_channel_groups() is None: color_groups = False diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 755e60ccbf..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -107,82 +107,90 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # and use custum grid spec fig = self.figure nrows = 2 - ncols = 3 - if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): + ncols = 2 + if sorting_analyzer.has_extension("correlograms"): + ncols += 1 + if sorting_analyzer.has_extension("waveforms"): ncols += 1 if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) + col_counter = 0 - if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, 0]) - # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - w = UnitLocationsWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_legend=False, - backend="matplotlib", - ax=ax1, - **unitlocationswidget_kwargs, - ) - - unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - x, y = unit_location[0], unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - w = UnitWaveformsWidget( + # Unit locations and unit waveform plots are always generated + ax_unit_locations = fig.add_subplot(gs[:2, col_counter]) + _ = UnitLocationsWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax_unit_locations, + **unitlocationswidget_kwargs, + ) + col_counter += 1 + + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax_unit_locations.set_xlim(x - 80, x + 80) + ax_unit_locations.set_ylim(y - 250, y + 250) + ax_unit_locations.set_xticks([]) + ax_unit_locations.set_xlabel(None) + ax_unit_locations.set_ylabel(None) + + ax_unit_waveforms = 
fig.add_subplot(gs[:2, col_counter]) + _ = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, + plot_waveforms=sorting_analyzer.has_extension("waveforms"), same_axis=True, plot_legend=False, sparsity=sparsity, backend="matplotlib", - ax=ax2, + ax=ax_unit_waveforms, **unitwaveformswidget_kwargs, ) + col_counter += 1 - ax2.set_title(None) + ax_unit_waveforms.set_title(None) - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - use_max_channel=True, - same_axis=False, - backend="matplotlib", - ax=ax3, - **unitwaveformdensitymapwidget_kwargs, - ) - ax3.set_ylabel(None) + if sorting_analyzer.has_extension("waveforms"): + ax_waveform_density = fig.add_subplot(gs[:2, col_counter]) + UnitWaveformDensityMapWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax_waveform_density, + **unitwaveformdensitymapwidget_kwargs, + ) + col_counter += 1 + ax_waveform_density.set_ylabel(None) if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, 3]) + ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", - ax=ax4, + ax=ax_correlograms, **autocorrelogramswidget_kwargs, ) + col_counter += 1 - ax4.set_title(None) - ax4.set_yticks([]) + ax_correlograms.set_title(None) + ax_correlograms.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) + ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1]) + ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1]) + axes = np.array([ax_spike_amps, ax_amps_distribution]) AmplitudesWidget( sorting_analyzer, unit_ids=[unit_id], diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index c593836061..3b31eacee5 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -565,7 +565,7 @@ def _update_plot(self, change): channel_locations = self.sorting_analyzer.get_channel_locations() else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] - templates = self.templates.templates_array[unit_indices] + templates = self.templates.get_dense_templates()[unit_indices] templates_shadings = None channel_locations = self.templates.get_channel_locations() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ca09cc4d8f..75c6248f0f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -243,3 +243,93 @@ def array_to_image( output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape) return output_image + + +def make_units_table_from_sorting(sorting, units_table=None): + """ + Make a DataFrame from sorting properties. + Only for properties with ndim=1 + + Parameters + ---------- + sorting : Sorting + The Sorting object + units_table : None | pd.DataFrame + Optionally a existing dataframe. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. 
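    Examples
    --------
    Editorial sketch, not part of the original diff; ``sorting`` is assumed to be any
    ``BaseSorting`` and ``existing_df`` a hypothetical pre-built DataFrame. Only
    1-dimensional properties with an int/uint/str/float/bool dtype are copied:

    >>> units_table = make_units_table_from_sorting(sorting)                     # one row per unit id
    >>> extended = make_units_table_from_sorting(sorting, units_table=existing_df)  # add columns to an existing table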
+ """ + + if units_table is None: + import pandas as pd + + units_table = pd.DataFrame(index=sorting.unit_ids) + + for col in sorting.get_property_keys(): + values = sorting.get_property(col) + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table + + +def make_units_table_from_analyzer( + analyzer, + extra_properties=None, +): + """ + Make a DataFrame by aggregating : + * quality metrics + * template metrics + * unit_position + * sorting properties + * extra columns + + This used in sortingview and spikeinterface-gui to display the units table in a flexible way. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + extra_properties : None | dict + Extra columns given as dict. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + import pandas as pd + + all_df = [] + + if analyzer.get_extension("unit_locations") is not None: + locs = analyzer.get_extension("unit_locations").get_data() + df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) + all_df.append(df) + + if analyzer.get_extension("quality_metrics") is not None: + df = analyzer.get_extension("quality_metrics").get_data() + all_df.append(df) + + if analyzer.get_extension("template_metrics") is not None: + df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) + + if len(all_df) > 0: + units_table = pd.concat(all_df, axis=1) + else: + units_table = pd.DataFrame(index=analyzer.unit_ids) + + make_units_table_from_sorting(analyzer.sorting, units_table=units_table) + + if extra_properties is not None: + for col, values in extra_properties.items(): + # the ndim = 1 is important because we need column only for the display in gui. 
+ if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index a6cc562ba2..d594414287 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -1,10 +1,12 @@ from __future__ import annotations +from warnings import warn + import numpy as np from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json -from warnings import warn +from .utils import make_units_table_from_sorting, make_units_table_from_analyzer def make_serializable(*args): @@ -50,105 +52,55 @@ def handle_display_and_url(widget, view, **backend_kwargs): def generate_unit_table_view( sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, unit_properties: list[str] | None = None, - similarity_scores: npndarray | None = None, + similarity_scores: np.ndarray | None = None, ): import sortingview.views as vv if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): analyzer = sorting_or_sorting_analyzer + units_tables = make_units_table_from_analyzer(analyzer) sorting = analyzer.sorting else: sorting = sorting_or_sorting_analyzer - analyzer = None - - # Find available unit properties from all sources - sorting_props = list(sorting.get_property_keys()) - if analyzer is not None: - if analyzer.get_extension("quality_metrics") is not None: - qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) - qm_data = analyzer.get_extension("quality_metrics").get_data() - else: - qm_props = [] - if analyzer.get_extension("template_metrics") is not None: - tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) - tm_data = analyzer.get_extension("template_metrics").get_data() - else: - tm_props = [] - # Check for any overlaps and warn user if any - all_props = sorting_props + qm_props + tm_props - else: - all_props = sorting_props - qm_props = [] - tm_props = [] - qm_data = None - tm_data = None - - overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] - if len(overlap_props) > 0: - warn( - f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}" - ) - - # Get unit properties + units_tables = make_units_table_from_sorting(sorting) + # analyzer = None + if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] else: + # keep only selected columns + unit_properties = np.array(unit_properties) + keep = np.isin(unit_properties, units_tables.columns) + if sum(keep) < len(unit_properties): + warn(f"Some unit properties are not in the sorting: {unit_properties[~keep]}") + unit_properties = unit_properties[keep] + units_tables = units_tables.loc[:, unit_properties] + + dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"} + ut_columns = [] + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + txt_dtype = dtype_convertor[values.dtype.kind] + ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=txt_dtype)) + ut_rows = [] - values = {} - valid_unit_properties = [] - - # Create columns for each property - for prop_name in unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - 
property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - else: - warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") - continue - - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - # Create rows for each unit - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - - # Check for NaN values and round floats - val0 = np.array(property_values[0]) - if val0.dtype.kind == "f": - if np.isnan(property_values[ui]): - continue - property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False) - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + for unit_index, unit_id in enumerate(sorting.unit_ids): + row_values = {} + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + value = values[unit_index] + if values.dtype.kind == "f": + # Check for NaN values and round floats + if np.isnan(values[unit_index]): + continue + value = np.format_float_positional(value, precision=4, fractional=False) + row_values[col] = value + ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table
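Editorial sketch, not part of the diff: the refactored generate_unit_table_view() now builds a single DataFrame up front and maps numpy dtype kinds onto sortingview column types. A minimal illustration of that flow, assuming `analyzer` is a SortingAnalyzer (extensions that were never computed are simply skipped by make_units_table_from_analyzer):

from spikeinterface.widgets.utils import make_units_table_from_analyzer

units_table = make_units_table_from_analyzer(analyzer)

# Same dtype-kind mapping as dtype_convertor in the diff above; columns whose dtype kind
# is not listed here are left out of the sortingview units table.
dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"}
for col in units_table.columns:
    kind = units_table[col].to_numpy().dtype.kind
    if kind in dtype_convertor:
        print(col, "->", dtype_convertor[kind])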