diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 723e8a702f..c2524d2c16 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -1,41 +1,20 @@ name: Install packages description: This action installs the package and its dependencies for testing -inputs: - python-version: - description: 'Python version to set up' - required: false - os: - description: 'Operating system to set up' - required: false - runs: using: "composite" steps: - name: Install dependencies run: | - sudo apt install git git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" - python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step - python -m pip install -U pip # Official recommended way - source ${{ github.workspace }}/test_env/bin/activate pip install tabulate # This produces summaries at the end pip install -e .[test,extractors,streaming_extractors,test_extractors,full] shell: bash - - name: Force installation of latest dev from key-packages when running dev (not release) - run: | - source ${{ github.workspace }}/test_env/bin/activate - spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") - if [ $spikeinterface_is_dev_version = "True" ]; then - echo "Running spikeinterface dev version" - pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo - pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface - fi - echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" + - name: Install git-annex shell: bash - - name: git-annex install run: | + pip install datalad-installer wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz mkdir /home/runner/work/installation mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ @@ -44,4 +23,14 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH cd $workdir + git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + - name: Force installation of latest dev from key-packages when running dev (not release) + run: | + spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)") + if [ $spikeinterface_is_dev_version = "True" ]; then + echo "Running spikeinterface dev version" + pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo + pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface + fi + echo "Running tests for release, using pyproject.toml versions of neo and probeinterface" shell: bash diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 3f9c2f0f63..cfab49ef09 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -47,7 +47,7 @@ jobs: echo "$file was changed" done - - name: Set testing environment # This decides which tests are run and whether to install especial dependencies + - name: Set testing environment # This decides which tests are run and whether to install special dependencies shell: bash run: | changed_files="${{ steps.changed-files.outputs.all_changed_files }}" diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 407c614ebf..f8ed2aa7a9 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -45,7 +45,6 @@ jobs: env: HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell run: | - source ${{ github.workspace }}/test_env/bin/activate pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY diff --git a/doc/api.rst b/doc/api.rst index 6bb9b39091..eb9a61eb9c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -346,6 +346,9 @@ spikeinterface.curation .. autofunction:: remove_redundant_units .. autofunction:: remove_duplicated_spikes .. autofunction:: remove_excess_spikes + .. autofunction:: load_model + .. autofunction:: auto_label_units + .. autofunction:: train_model Deprecated ~~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index e3d58ca8f2..d229dc18ee 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -119,12 +119,15 @@ # for sphinx gallery plugin sphinx_gallery_conf = { - 'only_warn_on_example_error': True, + # This is the default but including here explicitly. Should build all docs and fail on gallery failures only. + # other option would be abort_on_example_error, but this fails on first failure. So we decided against this. + 'only_warn_on_example_error': False, 'examples_dirs': ['../examples/tutorials'], 'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ '../examples/tutorials/core', '../examples/tutorials/extractors', + '../examples/tutorials/curation', '../examples/tutorials/qualitymetrics', '../examples/tutorials/comparison', '../examples/tutorials/widgets', diff --git a/doc/development/development.rst b/doc/development/development.rst index a91818a271..1638c41243 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -213,6 +213,25 @@ We use Sphinx to build the documentation. To build the documentation locally, yo This will build the documentation in the :code:`doc/_build/html` folder. You can open the :code:`index.html` file in your browser to see the documentation. +Adding new documentation +------------------------ + +Documentation can be added as a +`sphinx-gallery `_ +python file ('tutorials') +or a +`sphinx rst `_ +file (all other sections). + +To add a new tutorial, add your ``.py`` file to ``spikeinterface/examples``. +Then, update the ``spikeinterface/doc/tutorials_custom_index.rst`` file +to make a new card linking to the page and an optional image. See +``tutorials_custom_index.rst`` header for more information. + +For other sections, write your documentation in ``.rst`` format and add +the page to the appropriate ``index.rst`` file found in the relevant +folder (e.g. ``how_to/index.rst``). + How to run code coverage locally -------------------------------- To run code coverage locally, you can use the following command: diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3d45606a78..1349802ce5 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions 'min_spikes': 0, 'window_size_s': 1}, 'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'}, - 'synchrony': {'synchrony_sizes': (2, 4, 8)}} + 'synchrony': {} Since the recording is very short, let’s change some parameters to diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst new file mode 100644 index 0000000000..9b1612ec12 --- /dev/null +++ b/doc/how_to/auto_curation_prediction.rst @@ -0,0 +1,43 @@ +How to use a trained model to predict the curation labels +========================================================= + +For a more detailed guide to using trained models, `read our tutorial here +`_). + +There is a Collection of models for automated curation available on the +`SpikeInterface HuggingFace page `_. + +We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` on a SortingAnalyzer +called ``sorting_analyzer``. We assume that the quality and template metrics have +already been computed. + +We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the +repo's URL after huggingface.co/) and that we trust the model. + +.. code:: + + from spikeinterface.curation import auto_label_units + + labels_and_probabilities = auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trust_model = True + ) + +If you have a local directory containing the model in a ``skops`` file you can use this to +create the labels: + +.. code:: + + labels_and_probabilities = si.auto_label_units( + sorting_analyzer = sorting_analyzer, + model_folder = "my_folder_with_a_model_in_it", + ) + +The returned labels are a dictionary of model's predictions and it's confidence. These +are also saved as a property of your ``sorting_analyzer`` and can be accessed like so: + +.. code:: + + labels = sorting_analyzer.sorting.get_property("classifier_label") + probabilities = sorting_analyzer.sorting.get_property("classifier_probability") diff --git a/doc/how_to/auto_curation_training.rst b/doc/how_to/auto_curation_training.rst new file mode 100644 index 0000000000..20ab57d284 --- /dev/null +++ b/doc/how_to/auto_curation_training.rst @@ -0,0 +1,58 @@ +How to train a model to predict curation labels +=============================================== + +A full tutorial for model-based curation can be found `here `_. + +Here, we assume that you have: + +* Two SortingAnalyzers called ``analyzer_1`` and + ``analyzer_2``, and have calculated some template and quality metrics for both +* Manually curated labels for the units in each analyzer, in lists called + ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can + be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``. + +With these objects calculated, you can train a model as follows + +.. code:: + + from spikeinterface.curation import train_model + + analyzer_list = [analyzer_1, analyzer_2] + labels_list = [analyzer_1_labels, analyzer_2_labels] + output_folder = "/path/to/output_folder" + + trainer = train_model( + mode="analyzers", + labels=labels_list, + analyzers=analyzer_list, + output_folder=output_folder, + metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics + imputation_strategies=None, # Default is all available imputation strategies + scaling_techniques=None, # Default is all available scaling techniques + classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available + seed=None, # Set a seed for reproducibility + ) + + +The trainer tries several models and chooses the most accurate one. This model and +some metadata are stored in the ``output_folder``, which can later be loaded using the +``load_model`` function (`more details `_). +We can also access the model, which is an sklearn ``Pipeline``, from the trainer object + +.. code:: + + best_model = trainer.best_pipeline + + +The training function can also be run in “csv” mode, if you prefer to +store metrics in as .csv files. If the target labels are stored as a column in +the file, you can point to these with the ``target_label`` parameter + +.. code:: + + trainer = train_model( + mode="csv", + metrics_paths = ["/path/to/csv_file_1", "/path/to/csv_file_2"], + target_label = "my_label", + output_folder=output_folder, + ) diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 5d7eae9003..7f79156a3b 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -15,3 +15,5 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. load_your_data_into_sorting benchmark_with_hybrid_recordings drift_with_lfp + auto_curation_training + auto_curation_prediction diff --git a/doc/images/files_screen.png b/doc/images/files_screen.png new file mode 100644 index 0000000000..ef2b5b0873 Binary files /dev/null and b/doc/images/files_screen.png differ diff --git a/doc/images/hf-logo.svg b/doc/images/hf-logo.svg new file mode 100644 index 0000000000..ab959d165f --- /dev/null +++ b/doc/images/hf-logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/doc/images/initial_model_screen.png b/doc/images/initial_model_screen.png new file mode 100644 index 0000000000..b01c4248a6 Binary files /dev/null and b/doc/images/initial_model_screen.png differ diff --git a/doc/index.rst b/doc/index.rst index ed443e4200..e6d8aa3fea 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -51,7 +51,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview get_started/index - tutorials/index + tutorials_custom_index how_to/index modules/index api diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 8aa1815a55..5df9a7e6b1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -385,7 +385,7 @@ and merging unit groups. sorting_analyzer_select = sorting_analyzer.select_units(unit_ids=[0, 1, 2, 3]) sorting_analyzer_remove = sorting_analyzer.remove_units(remove_unit_ids=[0]) - sorting_analyzer_merge = sorting_analyzer.merge_units([0, 1], [2, 3]) + sorting_analyzer_merge = sorting_analyzer.merge_units([[0, 1], [2, 3]]) All computed extensions will be automatically propagated or merged when curating. Please refer to the :ref:`modules/curation:Curation module` documentation for more information. diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d115b33e4a..37de992806 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -88,7 +88,7 @@ The ``censored_period_ms`` parameter is the time window in milliseconds to consi The :py:func:`~spikeinterface.curation.remove_redundand_units` function removes redundant units from the sorting output. Redundant units are units that share over a certain percentage of spikes, by default 80%. -The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. +The function can act both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. .. code-block:: python @@ -102,13 +102,18 @@ The function can acto both on a ``BaseSorting`` or a ``SortingAnalyzer`` object. ) # remove redundant units from SortingAnalyzer object - clean_sorting_analyzer = remove_redundant_units( + # note this returns a cleaned sorting + clean_sorting = remove_redundant_units( sorting_analyzer, duplicate_threshold=0.9, remove_strategy="min_shift" ) + # in order to have a SortingAnalyer with only the non-redundant units one must + # select the designed units remembering to give format and folder if one wants + # a persistent SortingAnalyzer. + clean_sorting_analyzer = sorting_analyzer.select_units(clean_sorting.unit_ids) -We recommend usinf the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps +We recommend using the ``SortingAnalyzer`` approach, since the ``min_shift`` strategy keeps the unit (among the redundant ones), with a better template alignment. diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index d244fd0c0f..696dacbd3c 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index, within and across spike trains. -Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count. +Synchrony metrics are computed for 2, 4 and 8 synchronous spikes. @@ -29,7 +29,7 @@ Example code import spikeinterface.qualitymetrics as sqm # Combine a sorting and recording into a sorting_analyzer - synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8)) + synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer) # synchrony is a tuple of dicts with the synchrony metrics for each unit diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst new file mode 100644 index 0000000000..82f2c06eed --- /dev/null +++ b/doc/tutorials_custom_index.rst @@ -0,0 +1,229 @@ +.. This page provides a custom index to the 'Tutorials' page, rather than the default sphinx-gallery +.. generated page. The benefits of this are flexibility in design and inclusion of non-sphinx files in the index. +.. +.. To update this index with a new documentation page +.. 1) Copy the grid-item-card and associated ".. raw:: html" section. +.. 2) change :link: to a link to your page. If this is an `.rst` file, point to the rst file directly. +.. If it is a sphinx-gallery generated file, format the path as separated by underscore and prefix `sphx_glr`, +.. pointing to the .py file. e.g. `tutorials/my/page.py` -> `sphx_glr_tutorials_my_page.py +.. 3) Change :img-top: to point to the thumbnail image of your choosing. You can point to images generated +.. in the sphinx gallery page if you wish. +.. 4) In the `html` section, change the `default-title` to your pages title and `hover-content` to the subtitle. + +:orphan: + +Tutorials +============ + +Longer form tutorials about using SpikeInterface. Many of these are downloadable +as notebooks or Python scripts so that you can "code along" with the tutorials. + +If you're new to SpikeInterface, we recommend trying out the +:ref:`get_started/quickstart:Quickstart tutorial` first. + +Updating from legacy +-------------------- + +.. toctree:: + :maxdepth: 1 + + tutorials/waveform_extractor_to_sorting_analyzer + +Core tutorials +-------------- + +These tutorials focus on the :py:mod:`spikeinterface.core` module. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Recording objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_1_recording_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_1_recording_extractor_thumb.png + :img-alt: Recording objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Sorting objects + :link-type: ref + :link: sphx_glr_tutorials_core_plot_2_sorting_extractor.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_2_sorting_extractor_thumb.png + :img-alt: Sorting objects + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handling probe information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_3_handle_probe_info.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_3_handle_probe_info_thumb.png + :img-alt: Handling probe information + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingAnalyzer + :link-type: ref + :link: sphx_glr_tutorials_core_plot_4_sorting_analyzer.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_4_sorting_analyzer_thumb.png + :img-alt: SortingAnalyzer + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Append and/or concatenate segments + :link-type: ref + :link: sphx_glr_tutorials_core_plot_5_append_concatenate_segments.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_5_append_concatenate_segments_thumb.png + :img-alt: Append/Concatenate segments + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Handle time information + :link-type: ref + :link: sphx_glr_tutorials_core_plot_6_handle_times.py + :img-top: /tutorials/core/images/thumb/sphx_glr_plot_6_handle_times_thumb.png + :img-alt: Handle time information + :class-card: gallery-card + :text-align: center + +Extractors tutorials +-------------------- + +The :py:mod:`spikeinterface.extractors` module is designed to load and save recorded and sorted data, and to handle probe information. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Read various formats + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :img-alt: Read various formats + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Working with unscaled traces + :link-type: ref + :link: sphx_glr_tutorials_extractors_plot_2_working_with_unscaled_traces.py + :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_2_working_with_unscaled_traces_thumb.png + :img-alt: Unscaled traces + :class-card: gallery-card + :text-align: center + +Quality metrics tutorial +------------------------ + +The :code:`spikeinterface.qualitymetrics` module allows users to compute various quality metrics to assess the goodness of a spike sorting output. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Quality Metrics + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_metrics.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png + :img-alt: Quality Metrics + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Curation Tutorial + :link-type: ref + :link: sphx_glr_tutorials_qualitymetrics_plot_4_curation.py + :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_4_curation_thumb.png + :img-alt: Curation Tutorial + :class-card: gallery-card + :text-align: center + +Automated curation tutorials +---------------------------- + +Learn how to curate your units using a trained machine learning model. Or how to create +and share your own model. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Model-based curation + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_1_automated_curation.py + :img-top: /tutorials/curation/images/sphx_glr_plot_1_automated_curation_002.png + :img-alt: Model-based curation + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Train your own model + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_2_train_a_model.py + :img-top: /tutorials/curation/images/thumb/sphx_glr_plot_2_train_a_model_thumb.png + :img-alt: Train your own model + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Upload your model to HuggingFaceHub + :link-type: ref + :link: sphx_glr_tutorials_curation_plot_3_upload_a_model.py + :img-top: /images/hf-logo.svg + :img-alt: Upload your model + :class-card: gallery-card + :text-align: center + +Comparison tutorial +------------------- + +The :code:`spikeinterface.comparison` module allows you to compare sorter outputs or benchmark against ground truth. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: Sorter Comparison + :link-type: ref + :link: sphx_glr_tutorials_comparison_plot_5_comparison_sorter_weaknesses.py + :img-top: /tutorials/comparison/images/thumb/sphx_glr_plot_5_comparison_sorter_weaknesses_thumb.png + :img-alt: Sorter Comparison + :class-card: gallery-card + :text-align: center + +Widgets tutorials +----------------- + +The :code:`widgets` module contains several plotting routines (widgets) for visualizing recordings, sorting data, probe layout, and more. + +.. grid:: 1 2 2 3 + :gutter: 2 + + .. grid-item-card:: RecordingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_1_rec_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_1_rec_gallery_thumb.png + :img-alt: Recording Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: SortingExtractor Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_2_sort_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_2_sort_gallery_thumb.png + :img-alt: Sorting Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Waveforms Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_3_waveforms_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_3_waveforms_gallery_thumb.png + :img-alt: Waveforms Widgets + :class-card: gallery-card + :text-align: center + + .. grid-item-card:: Peaks Widgets + :link-type: ref + :link: sphx_glr_tutorials_widgets_plot_4_peaks_gallery.py + :img-top: /tutorials/widgets/images/thumb/sphx_glr_plot_4_peaks_gallery_thumb.png + :img-alt: Peaks Widgets + :class-card: gallery-card + :text-align: center + +Download All Examples +--------------------- + +- :download:`Download all examples in Python source code ` +- :download:`Download all examples in Jupyter notebooks ` diff --git a/examples/tutorials/curation/README.rst b/examples/tutorials/curation/README.rst new file mode 100644 index 0000000000..0f64179e65 --- /dev/null +++ b/examples/tutorials/curation/README.rst @@ -0,0 +1,5 @@ +Curation tutorials +------------------ + +Learn how to use models to automatically curated your sorted data, or generate models +based on your own curation. diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py new file mode 100644 index 0000000000..e88b0973df --- /dev/null +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -0,0 +1,287 @@ +""" +Model-based curation tutorial +============================= + +Sorters are not perfect. They output excellent units, as well as noisy ones, and ones that +should be split or merged. Hence one should curate the generated units. Historically, this +has been done using laborious manual curation. An alternative is to use automated methods +based on metrics which quantify features of the units. In spikeinterface these are the +quality metrics and the template metrics. A simple approach is to use thresholding: +only accept units whose metrics pass a certain quality threshold. Another approach is to +take one (or more) manually labelled sortings, whose metrics have been computed, and train +a machine learning model to predict labels. + +This notebook provides a step-by-step guide on how to take a machine learning model that +someone else has trained and use it to curate your own spike sorted output. SpikeInterface +also provides the tools to train your own model, +`which you can learn about here `_. + +We'll download a toy model and use it to label our sorted data. We start by importing some packages +""" + +import warnings +warnings.filterwarnings("ignore") +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import spikeinterface.core as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# note: you can use more cores using e.g. +# si.set_global_jobs_kwargs(n_jobs = 8) + +############################################################################## +# Download a pretrained model +# --------------------------- +# +# Let's download a pretrained model from `Hugging Face `_ (HF), +# a model sharing platform focused on AI and ML models and datasets. The +# ``load_model`` function allows us to download directly from HF, or use a model in a local +# folder. The function downloads the model and saves it in a temporary folder and returns a +# model and some metadata about the model. + +model, model_info = sc.load_model( + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + + +############################################################################## +# This model was trained on artifically generated tetrode data. There are also models trained +# on real data, like the one discussed `below <#A-model-trained-on-real-Neuropixels-data>`_. +# Each model object has a nice html representation, which will appear if you're using a Jupyter notebook. + +model + +############################################################################## +# This tells us more information about the model. The one we've just downloaded was trained used +# a ``RandomForestClassifier```. You can also discover this information by running +# ``model.get_params()``. The model object (an `sklearn Pipeline `_) also contains information +# about which metrics were used to compute the model. We can access it from the model (or from the model_info) + +print(model.feature_names_in_) + +############################################################################## +# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed. +# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a +# bunch of extensions. Follow these links for more info on `recordings `_, `sortings `_, `sorting analyzers `_ +# and `extensions `_. + +recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10) +sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording) +sorting_analyzer.compute(['noise_levels','random_spikes','waveforms','templates','spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) +sorting_analyzer.compute('template_metrics', include_multi_channel_metrics=True) + +############################################################################## +# This sorting_analyzer now contains the required quality metrics and template metrics. +# We can check that this is true by accessing the extension data. + +all_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys()) + list(sorting_analyzer.get_extension('template_metrics').get_data().keys()) +print(set(all_metric_names) == set(model.feature_names_in_)) + +############################################################################## +# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly +# to the ``auto_label_units`` function. This returns a dictionary containing a label and +# a confidence for each unit contained in the ``sorting_analyzer``. + +labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "SpikeInterface/toy_tetrode_model", + trusted = ['numpy.dtype'] +) + +print(labels) + + +############################################################################## +# The model has labelled one unit as bad. Let's look at that one, and also the 'good' unit +# with the highest confidence of being 'good'. + +sw.plot_unit_templates(sorting_analyzer, unit_ids=['7','9']) + +############################################################################## +# Nice! Unit 9 looks more like an expected action potential waveform while unit 7 doesn't, +# and it seems reasonable that unit 7 is labelled as `bad`. However, for certain experiments +# or brain areas, unit 7 might be a great small-amplitude unit. This example highlights that +# you should be careful applying models trained on one dataset to your own dataset. You can +# explore the currently available models on the `spikeinterface hugging face hub `_ +# page, or `train your own one `_. +# +# Assess the model performance +# ---------------------------- +# +# To assess the performance of the model relative to labels assigned by a human creator, we can load or generate some +# "human labels", and plot a confusion matrix of predicted vs human labels for all clusters. Here +# we'll be a conservative human, who has labelled several units with small amplitudes as 'bad'. + +human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good'] + +# Note: if you labelled using phy, you can load the labels using: +# human_labels = sorting_analyzer.sorting.get_property('quality') +# We need to load in the `label_conversion` dictionary, which converts integers such +# as '0' and '1' to readable labels such as 'good' and 'bad'. This is stored as +# in `model_info`, which we loaded earlier. + +from sklearn.metrics import confusion_matrix, balanced_accuracy_score + +label_conversion = model_info['label_conversion'] +predictions = labels['prediction'] + +conf_matrix = confusion_matrix(human_labels, predictions) + +# Calculate balanced accuracy for the confusion matrix +balanced_accuracy = balanced_accuracy_score(human_labels, predictions) + +plt.imshow(conf_matrix) +for (index, value) in np.ndenumerate(conf_matrix): + plt.annotate( str(value), xy=index, color="white", fontsize="15") +plt.xlabel('Predicted Label') +plt.ylabel('Human Label') +plt.xticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.yticks(ticks = [0, 1], labels = list(label_conversion.values())) +plt.title('Predicted vs Human Label') +plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}") +plt.show() + + +############################################################################## +# Here, there are several false positives (if we consider the human labels to be "the truth"). +# +# Next, we can also see how the model's confidence relates to the probability that the model +# label matches the human label. +# +# This could be used to help decide which units should be auto-curated and which need further +# manual creation. For example, we might accept any unit as 'good' that the model predicts +# as 'good' with confidence over a threshold, say 80%. If the confidence is lower we might decide to take a +# look at this unit manually. Below, we will create a plot that shows how the agreement +# between human and model labels changes as we increase the confidence threshold. We see that +# the agreement increases as the confidence does. So the model gets more accurate with a +# higher confidence threshold, as expceted. + + +def calculate_moving_avg(label_df, confidence_label, window_size): + + label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop') + # Group by decile and calculate the proportion of correct labels (agreement) + p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean() + # Convert decile to range 0-1 + p_label_grouped.index = p_label_grouped.index / 10 + # Sort the DataFrame by confidence scores + label_df_sorted = label_df.sort_values(by=confidence_label) + + p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean() + + return label_df_sorted[confidence_label], p_label_moving_avg + +confidences = labels['probability'] + +# Make dataframe of human label, model label, and confidence +label_df = pd.DataFrame(data = { + 'human_label': human_labels, + 'decoder_label': predictions, + 'confidence': confidences}, + index = sorting_analyzer.sorting.get_unit_ids()) + +# Calculate the proportion of agreed labels by confidence decile +label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label'] + +p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3) + +# Plot the moving average of agreement +plt.figure(figsize=(6, 6)) +plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average') +plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance') +plt.xlabel('Confidence'); #plt.xlim(0.5, 1) +plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1) +plt.title('Agreement vs Confidence (Moving Average)') +plt.legend(); plt.grid(True); plt.show() + +############################################################################## +# In this case, you might decide to only trust labels which had confidence over above 0.88, +# and manually labels the ones the model isn't so confident about. +# +# A model trained on real Neuropixels data +# ---------------------------------------- +# +# Above, we used a toy model trained on generated data. There are also models on HuggingFace +# trained on real data. +# +# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in +# V1,SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and +# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into +# `noise` or `not-noise` and the other will classify the `not-noise` units into single +# unit activity (sua) units and multi-unit activity (mua) units. +# +# There is more information about the model on the model's HuggingFace page. Take a look! +# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one. +# We can do so as follows: +# + +# Apply the noise/not-noise model +noise_neuron_labels = sc.auto_label_units( + sorting_analyzer = sorting_analyzer, + repo_id = "AnoushkaJain3/noise_neural_classifier", + trust_model=True, +) + +noise_units = noise_neuron_labels[noise_neuron_labels['prediction']=='noise'] +analyzer_neural = sorting_analyzer.remove_units(noise_units.index) + +# Apply the sua/mua model +sua_mua_labels = sc.auto_label_units( + sorting_analyzer = analyzer_neural, + repo_id = "AnoushkaJain3/sua_mua_classifier", + trust_model=True, +) + +all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index() +print(all_labels) + +############################################################################## +# If you run this without the ``trust_model=True`` parameter, you will receive an error: +# +# .. code-block:: +# +# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold'] +# +# This is a security warning, which can be overcome by passing the trusted types list +# ``trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']`` +# or by passing the ``trust_model=True``` keyword. +# +# .. dropdown:: More about security +# +# Sharing models, with are Python objects, is complicated. +# We have chosen to use the `skops format `_, instead +# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues +# `here `_). While unpacking the ``.skops`` file, each function +# is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and +# allow the object to be loaded if it only contains these (and no unkown malicious code). But +# when ``skops`` it's not sure, it raises an error. Here, it doesn't recognise +# ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', +# 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', +# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions +# from `sklearn`, and we can happily add them to the ``trusted`` functions to load. +# +# In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos, +# especially from unknown sources. +# +# Directly applying a sklearn Pipeline +# ------------------------------------ +# +# Instead of using ``HuggingFace`` and ``skops``, someone might have given you a model +# in differet way: perhaps by e-mail or a download. If you have the model in a +# folder, you can apply it in a very similar way: +# +# .. code-block:: +# +# labels = sc.auto_label_units( +# sorting_analyzer = sorting_analyzer, +# model_folder = "path/to/model/folder", +# ) + +############################################################################## +# Using this, you lose the advantages of the model metadata: the quality metric parameters +# are not checked and the labels are not converted their original human readable names (like +# 'good' and 'bad'). Hence we advise using the methods discussed above, when possible. diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py new file mode 100644 index 0000000000..1a38836527 --- /dev/null +++ b/examples/tutorials/curation/plot_2_train_a_model.py @@ -0,0 +1,168 @@ +""" +Training a model for automated curation +============================= + +If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface. +""" + + +############################################################################## +# Step 1: Generate and label data +# ------------------------------- +# +# First we will import our dependencies +import warnings +warnings.filterwarnings("ignore") +from pathlib import Path +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +import spikeinterface.core as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# Note, you can set the number of cores you use using e.g. +# si.set_global_job_kwargs(n_jobs = 8) + +############################################################################## +# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll +# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will +# perfectly match the spikes in the recording. Hence this will contain good units. However, we've +# uncoupled :code:`sorting_2` to the recording and the spike times will not be matched with the spikes in the recording. +# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting +# object using :code:`si.aggregate_units`. +# +# (When making your own model, you should +# `load your own recording `_ +# and `do a sorting `_ on your data.) + +recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5) +_, sorting_2 =si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5) + +both_sortings = si.aggregate_units([sorting_1, sorting_2]) + +############################################################################## +# To do some visualisation and postprocessing, we need to create a sorting analyzer, and +# compute some extensions: + +analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording) +analyzer.compute(['noise_levels','random_spikes','waveforms','templates']) + +############################################################################## +# Now we can plot the templates for the first and fifth units. The first (unit id 0) belongs to +# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2` +# so should look like noise. + +sw.plot_unit_templates(analyzer, unit_ids=["0", "5"]) + +############################################################################## +# This is as expected: great! (Find out more about plotting using widgets `here `_.) +# We've set up our system so that the first five units are 'good' and the next five are 'bad'. +# So we can make a list of labels which contain this information. For real data, you could +# use a manual curation tool to make your own list. + +labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad'] + +############################################################################## +# Step 2: Train our model +# ----------------------- +# +# We'll now train a model, based on our labelled data. The model will be trained using properties +# of the units, and then be applied to units from other sortings. The properties we use are the +# `quality metrics `_ +# and `template metrics `_. +# Hence we need to compute these, using some ``sorting_analyzer``` extensions. + +analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics']) + +############################################################################## +# Now that we have metrics and labels, we're ready to train the model using the +# ``train_model``` function. The trainer will try several classifiers, imputation strategies and +# scaling techniques then save the most accurate. To save time in this tutorial, +# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling +# technique (standard scaler). +# +# We will use a list of one analyzer here, so the model is trained on a single +# session. In reality, we would usually train a model using multiple analyzers from an +# experiment, which should make the model more robust. To do this, you can simply pass +# a list of analyzers and a list of manually curated labels for each +# of these analyzers. Then the model would use all of these data as input. + +trainer = sc.train_model( + mode = "analyzers", # You can supply a labelled csv file instead of an analyzer + labels = [labels], + analyzers = [analyzer], + folder = "my_folder", # Where to save the model and model_info.json file + metric_names = None, # Specify which metrics to use for training: by default uses those already calculted + imputation_strategies = ["median"], # Defaults to all + scaling_techniques = ["standard_scaler"], # Defaults to all + classifiers = None, # Default to Random Forest only. Other classifiers you can try [ "AdaBoostClassifier","GradientBoostingClassifier","LogisticRegression","MLPClassifier"] + overwrite = True, # Whether or not to overwrite `folder` if it already exists. Default is False. + search_kwargs = {'cv': 3} # Parameters used during the model hyperparameter search +) + +best_model = trainer.best_pipeline + +############################################################################## +# +# You can pass many sklearn `classifiers `_ +# `imputation strategies `_ and +# `scalers `_, although the +# documentation is quite overwhelming. You can find the classifiers we've tried out +# using the ``sc.get_default_classifier_search_spaces`` function. +# +# The above code saves the model in ``model.skops``, some metadata in +# ``model_info.json`` and the model accuracies in ``model_accuracies.csv`` +# in the specified ``folder`` (in this case ``'my_folder'``). +# +# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more `_.) +# +# The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the +# tested models. Let's take a look: + +accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0) +accuracies.head() + +############################################################################## +# Our model is perfect!! This is because the task was *very* easy. We had 10 units; where +# half were pure noise and half were not. +# +# The model also contains some more information, such as which features are "important", +# as defined by sklearn (learn about feature importance of a Random Forest Classifier +# `here `_.) +# We can plot these: + +# Plot feature importances +importances = best_model.named_steps['classifier'].feature_importances_ +indices = np.argsort(importances)[::-1] + +# The sklearn importances are not computed for inputs whose values are all `nan`. +# Hence, we need to pick out the non-`nan` columns of our metrics +features = best_model.feature_names_in_ +n_features = best_model.n_features_in_ + +metrics = pd.concat([analyzer.get_extension('quality_metrics').get_data(), analyzer.get_extension('template_metrics').get_data()], axis=1) +non_null_metrics = ~(metrics.isnull().all()).values + +features = features[non_null_metrics] +n_features = len(features) + +plt.figure(figsize=(12, 7)) +plt.title("Feature Importances") +plt.bar(range(n_features), importances[indices], align="center") +plt.xticks(range(n_features), features[indices], rotation=90) +plt.xlim([-1, n_features]) +plt.subplots_adjust(bottom=0.3) +plt.show() + +############################################################################## +# Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio" +# but is not using "sync_spike_4" and "rp_contanimation". This is a toy model, so don't +# take these results seriously. But using this information, you could retrain another, +# simpler model using a subset of the metrics, by passing, e.g., +# ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to the ``train_model`` function. +# +# Now that you have a model, you can `apply it to another sorting +# `_ +# or `upload it to HuggingFaceHub `_. diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py new file mode 100644 index 0000000000..0a9ea402db --- /dev/null +++ b/examples/tutorials/curation/plot_3_upload_a_model.py @@ -0,0 +1,139 @@ +""" +Upload a pipeline to Hugging Face Hub +===================================== +""" +############################################################################## +# In this tutorial we will upload a pipeline, trained in SpikeInterface, to the +# `Hugging Face Hub `_ (HFH). +# +# To do this, you first need to train a model. `Learn how here! `_ +# +# Hugging Face Hub? +# ----------------- +# Hugging Face Hub (HFH) is a model sharing platform focused on AI and ML models and datasets. +# To upload your own model to HFH, you need to make an account with them. +# If you do not want to make an account, you can simply share the model folder with colleagues. +# There are also several ways to interaction with HFH: the way we propose here doesn't use +# many of the tools ``skops`` and hugging face have developed such as the ``Card`` and +# ``hub_utils``. Feel free to check those out `here `_. +# +# Prepare your model +# ------------------ +# +# The plan is to make a folder with the following file structure +# +# .. code-block:: +# +# my_model_folder/ +# my_model_name.skops +# model_info.json +# training_data.csv +# labels.csv +# metadata.json +# +# SpikeInterface and HFH don't require you to keep this folder structure, we just advise it as +# best practice. +# +# If you've used SpikeInterface to train your model, the ``train_model`` function auto-generates +# most of this data. The only thing missing is the the ``metadata.json`` file. The purpose of this +# file is to detail how the model was trained, which can help prospective users decide if it +# is relevant for them. For example, taking +# a model trained on mouse data and applying it to a primate is likely a bad idea (or a +# great research paper!). And a model trained using tetrode data might have limited application +# on a silcone high-density probes. Hence we suggest saving at least the species, brain areas +# and probe information, as is done in the dictionary below. Note that we format the metadata +# so that the information +# in common with the NWB data format is consistent with it. Since the models can be trained +# on several curations, all the metadata fields are lists: +# +# .. code-block:: +# +# import json +# +# model_metadata = { +# "subject_species": ["Mus musculus"], +# "brain_areas": ["CA1"], +# "probes": +# [{ +# "manufacturer": "IMEc", +# "name": "Neuropixels 2.0" +# }] +# } +# with open("my_model_folder/metadata.json", "w") as file: +# json.dump(model_metadata, file) +# +# Upload to HuggingFaceHub +# ------------------------ +# +# We'll now upload this folder to HFH using the web interface. +# +# First, go to https://huggingface.co/ and make an account. Once you've logged in, press +# ``+`` then ``New model`` or find ``+ New Model`` in the user menu. You will be asked +# to enter a model name, to choose a license for the model and whether the model should +# be public or private. After you have made these choices, press ``Create Model``. +# +# You should be on your model's landing page, whose header looks something like +# +# .. image:: ../../images/initial_model_screen.png +# :width: 550 +# :align: center +# :alt: The page shown on HuggingFaceHub when a user first initialises a model +# +# Click Files, then ``+ Add file`` then ``Upload file(s)``. You can then add your files to the repository. Upload these by pressing ``Commit changes to main``. +# +# You are returned to the Files page, which should look similar to +# +# .. image:: ../../images/files_screen.png +# :width: 700 +# :align: center +# :alt: The file list for a model HuggingFaceHub. +# +# Let's add some information about the model for users to see when they go on your model's +# page. Click on ``Model card`` then ``Edit model card``. Here is a sample model card for +# For a model based on synthetically generated tetrode data, +# +# .. code-block:: +# +# --- +# license: mit +# --- +# +# ## Model description +# +# A toy model, trained on toy data generated from spikeinterface. +# +# # Intended use +# +# Used to try out automated curation in SpikeInterface. +# +# # How to Get Started with the Model +# +# This can be used to automatically label a sorting in spikeinterface. Provided you have a `sorting_analyzer`, it is used as follows +# +# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading) +# +# from spikeinterface.curation import auto_label_units +# labels = auto_label_units( +# sorting_analyzer = sorting_analyzer, +# repo_id = "SpikeInterface/toy_tetrode_model", +# trust_model=True +# ) +# ` ` ` +# +# or you can download the entire repositry to `a_folder_for_a_model`, and use +# +# ` ` ` python +# from spikeinterface.curation import auto_label_units +# +# labels = auto_label_units( +# sorting_analyzer = sorting_analyzer, +# model_folder = "path/to/a_folder_for_a_model", +# trusted = ['numpy.dtype'] +# ) +# ` ` ` +# +# # Authors +# +# Chris Halcrow +# +# You can see the repo with this Model card `here `_. diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_mertics.py b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py similarity index 100% rename from examples/tutorials/qualitymetrics/plot_3_quality_mertics.py rename to examples/tutorials/qualitymetrics/plot_3_quality_metrics.py diff --git a/examples/tutorials/widgets/plot_2_sort_gallery.py b/examples/tutorials/widgets/plot_2_sort_gallery.py index da5c611ce4..056b5e3a8d 100644 --- a/examples/tutorials/widgets/plot_2_sort_gallery.py +++ b/examples/tutorials/widgets/plot_2_sort_gallery.py @@ -31,14 +31,14 @@ # plot_autocorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) ############################################################################## # plot_crosscorrelograms() # ~~~~~~~~~~~~~~~~~~~~~~~~ -w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5]) +w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5']) plt.show() diff --git a/pyproject.toml b/pyproject.toml index a43ab63c8e..0b2f06049f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.9,<4.0" +requires-python = ">=3.9,<3.13" # Only numpy 2.1 supported on python 3.13 for windows. We need to wait for fix on neo classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", @@ -73,7 +73,7 @@ extractors = [ ] streaming_extractors = [ - "ONE-api>=2.7.0", # alf sorter and streaming IBL + "ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL "ibllib>=2.36.0", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", @@ -101,6 +101,8 @@ full = [ "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", "numba", + "skops", + "huggingface_hub" ] widgets = [ @@ -171,6 +173,10 @@ test = [ "torch", "pynndescent", + # curation + "skops", + "huggingface_hub", + # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", @@ -192,6 +198,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions "networkx", + "skops", # For auotmated curation + "scikit-learn", # For auotmated curation # Download data "pooch>=1.8.2", "datalad>=1.0.2", diff --git a/readthedocs.yml b/readthedocs.yml index 512fcbc709..c6c44d83a0 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,5 +1,9 @@ version: 2 +sphinx: + # Path to your Sphinx configuration file. + configuration: doc/conf.py + build: os: ubuntu-22.04 tools: diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index b9cbf269c8..fc1b136d2d 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in case_keys: result_folder = self.folder / "results" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) if keep and result_folder.exists(): continue - elif not keep and result_folder.exists(): + elif not keep and (result_folder.exists() or sorter_folder.exists()): self.remove_benchmark(key) job_keys.append(key) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py index abb2a51bae..5a3c490d38 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,6 +109,8 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) + self.result["peaks"] = peaks + self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion @@ -131,6 +133,8 @@ def compute_result(self, **result_params): self.result["motion"] = motion _run_key_saved = [ + ("peaks", "npy"), + ("peak_locations", "npy"), ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] @@ -161,7 +165,9 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift( + self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6) + ): import matplotlib.pyplot as plt if case_keys is None: @@ -195,6 +201,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + if raster: + peaks = bench.result["peaks"] + peak_locations = bench.result["peak_locations"] + rec = bench.recording + x = peaks["sample_index"] / rec.sampling_frequency + y = peak_locations[bench.direction] + ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index f9267c785a..3cf6dca04f 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -56,6 +56,15 @@ def create_benchmark(self, key): benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) return benchmark + def remove_benchmark(self, key): + BenchmarkStudy.remove_benchmark(self, key) + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + import shutil + + if sorter_folder.exists(): + shutil.rmtree(sorter_folder) + def get_performance_by_unit(self, case_keys=None): import pandas as pd diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc5de63d07..447bbe562e 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -22,21 +22,23 @@ class ComputeRandomSpikes(AnalyzerExtension): """ - AnalyzerExtension that select some random spikes. + AnalyzerExtension that select somes random spikes. + This allows for a subsampling of spikes for further calculations and is important + for managing that amount of memory and speed of computation in the analyzer. This will be used by the `waveforms`/`templates` extensions. - This internally use `random_spikes_selection()` parameters are the same. + This internally uses `random_spikes_selection()` parameters. Parameters ---------- - method: "uniform" | "all", default: "uniform" + method : "uniform" | "all", default: "uniform" The method to select the spikes - max_spikes_per_unit: int, default: 500 + max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" - margin_size: int, default: None + margin_size : int, default: None A margin on each border of segments to avoid border spikes, ignored if method="all" - seed: int or None, default: None + seed : int or None, default: None A seed for the random generator, ignored if method="all" Returns @@ -104,7 +106,7 @@ def get_random_spikes(self): return self._some_spikes def get_selected_indices_in_spike_train(self, unit_id, segment_index): - # usefull for Waveforms extractor backwars compatibility + # useful for WaveformExtractor backwards compatibility # In Waveforms extractor "selected_spikes" was a dict (key: unit_id) of list (segment_index) of indices of spikes in spiketrain sorting = self.sorting_analyzer.sorting random_spikes_indices = self.data["random_spikes_indices"] @@ -133,16 +135,16 @@ class ComputeWaveforms(AnalyzerExtension): Parameters ---------- - ms_before: float, default: 1.0 + ms_before : float, default: 1.0 The number of ms to extract before the spike events - ms_after: float, default: 2.0 + ms_after : float, default: 2.0 The number of ms to extract after the spike events - dtype: None | dtype, default: None + dtype : None | dtype, default: None The dtype of the waveforms. If None, the dtype of the recording is used. Returns ------- - waveforms: np.ndarray + waveforms : np.ndarray Array with computed waveforms with shape (num_random_spikes, num_samples, num_channels) """ @@ -380,7 +382,12 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N assert isinstance(operators, list) for operator in operators: if isinstance(operator, str): - assert operator in ("average", "std", "median", "mad") + if operator not in ("average", "std", "median", "mad"): + error_msg = ( + f"You have entered an operator {operator} in your `operators` argument which is " + f"not supported. Please use any of ['average', 'std', 'median', 'mad'] instead." + ) + raise ValueError(error_msg) else: assert isinstance(operator, (list, tuple)) assert len(operator) == 2 @@ -405,9 +412,13 @@ def _run(self, verbose=False, **job_kwargs): self._compute_and_append_from_waveforms(self.params["operators"]) else: - for operator in self.params["operators"]: - if operator not in ("average", "std"): - raise ValueError(f"Computing templates with operators {operator} needs the 'waveforms' extension") + bad_operator_list = [ + operator for operator in self.params["operators"] if operator not in ("average", "std") + ] + if len(bad_operator_list) > 0: + raise ValueError( + f"Computing templates with operators {bad_operator_list} requires the 'waveforms' extension" + ) recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -441,7 +452,7 @@ def _run(self, verbose=False, **job_kwargs): def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"): - raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension") + raise ValueError(f"Computing templates with operators {operators} requires the 'waveforms' extension") unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids @@ -466,7 +477,7 @@ def _compute_and_append_from_waveforms(self, operators): assert self.sorting_analyzer.has_extension( "random_spikes" - ), "compute templates requires the random_spikes extension. You can run sorting_analyzer.get_random_spikes()" + ), "compute 'templates' requires the random_spikes extension. You can run sorting_analyzer.compute('random_spikes')" some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() for unit_index, unit_id in enumerate(unit_ids): spike_mask = some_spikes["unit_index"] == unit_index @@ -549,9 +560,17 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator=percentile`" key = f"percentile_{percentile}" + if key not in self.data.keys(): + error_msg = ( + f"You have entered `operator={key}`, but the only operators calculated are " + f"{list(self.data.keys())}. Please use one of these as your `operator` in the " + f"`get_data` function." + ) + raise ValueError(error_msg) + templates_array = self.data[key] if outputs == "numpy": @@ -566,7 +585,7 @@ def _get_data(self, operator="average", percentile=None, outputs="numpy"): probe=self.sorting_analyzer.get_probe(), ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("outputs must be `numpy` or `Templates`") def get_templates(self, unit_ids=None, operator="average", percentile=None, save=True, outputs="numpy"): """ @@ -576,26 +595,26 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save Parameters ---------- - unit_ids: list or None + unit_ids : list or None Unit ids to retrieve waveforms for - operator: "average" | "median" | "std" | "percentile", default: "average" + operator : "average" | "median" | "std" | "percentile", default: "average" The operator to compute the templates - percentile: float, default: None + percentile : float, default: None Percentile to use for operator="percentile" - save: bool, default True + save : bool, default: True In case, the operator is not computed yet it can be saved to folder or zarr - outputs: "numpy" | "Templates" + outputs : "numpy" | "Templates", default: "numpy" Whether to return a numpy array or a Templates object Returns ------- - templates: np.array + templates : np.array | Templates The returned templates (num_units, num_samples, num_channels) """ if operator != "percentile": key = operator else: - assert percentile is not None, "You must provide percentile=..." + assert percentile is not None, "You must provide percentile=... if `operator='percentile'`" key = f"pencentile_{percentile}" if key in self.data: @@ -632,7 +651,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save is_scaled=self.sorting_analyzer.return_scaled, ) else: - raise ValueError("outputs must be numpy or Templates") + raise ValueError("`outputs` must be 'numpy' or 'Templates'") def get_unit_template(self, unit_id, operator="average"): """ @@ -642,7 +661,7 @@ def get_unit_template(self, unit_id, operator="average"): ---------- unit_id: str | int Unit id to retrieve waveforms for - operator: str + operator: str, default: "average" The operator to compute the templates Returns @@ -691,22 +710,23 @@ class ComputeNoiseLevels(AnalyzerExtension): need_recording = True use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True def __init__(self, sorting_analyzer): AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None): - params = dict(num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) + def _set_params(self, **noise_level_params): + params = noise_level_params.copy() return params def _select_extension_data(self, unit_ids): - # this do not depend on units + # this does not depend on units return self.data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - # this do not depend on units + # this does not depend on units return self.data.copy() def _run(self, verbose=False): @@ -717,6 +737,15 @@ def _run(self, verbose=False): def _get_data(self): return self.data["noise_levels"] + def _handle_backward_compatibility_on_load(self): + # The old parameters used to be params=dict(num_chunks_per_segment=20, chunk_size=10000, seed=None) + # now it is handle more explicitly using random_slices_kwargs=dict() + for key in ("num_chunks_per_segment", "chunk_size", "seed"): + if key in self.params: + if "random_slices_kwargs" not in self.params: + self.params["random_slices_kwargs"] = dict() + self.params["random_slices_kwargs"][key] = self.params.pop(key) + register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 5e2e9e4014..7ca527e255 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,4 +1,5 @@ from __future__ import annotations + import warnings from pathlib import Path @@ -7,14 +8,9 @@ from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets -from .core_tools import ( - convert_bytes_to_str, - convert_seconds_to_str, -) -from .recording_tools import write_binary_recording - - +from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs +from .recording_tools import write_binary_recording class BaseRecording(BaseRecordingSnippets): @@ -509,6 +505,35 @@ def reset_times(self): rs.t_start = None rs.sampling_frequency = self.sampling_frequency + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. + + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for idx in segments_to_shift: + rs = self._recording_segments[idx] + + if self.has_time_vector(segment_index=idx): + rs.time_vector += shift + else: + rs.t_start += shift + def sample_index_to_time(self, sample_ind, segment_index=None): """ Transform sample index into time in seconds @@ -921,11 +946,11 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = round(sample_index) + sample_index = np.round(sample_index).astype(int) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - return int(sample_index) + return sample_index def get_num_samples(self) -> int: """Returns the number of samples in this signal segment diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 310533c96b..2ec3664a45 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -172,8 +172,10 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( - f"The given Probe have 'device_channel_indices' that do not match channel count \n" - f"{number_of_device_channel_indices} vs {self.get_num_channels()} \n" + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2af48407a3..9a0e242d62 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -135,7 +135,7 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, - unit_id, + unit_id: str | int, segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0316b3bab1..aa69fe585b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Literal +from typing import Literal, Optional from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -134,7 +134,7 @@ def generate_sorting( seed = _ensure_seed(seed) rng = np.random.default_rng(seed) num_segments = len(durations) - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] spikes = [] for segment_index in range(num_segments): @@ -1111,7 +1111,7 @@ def __init__( """ - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] super().__init__(sampling_frequency, unit_ids) self.num_units = num_units @@ -1138,6 +1138,7 @@ def __init__( firing_rates=firing_rates, refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, + unit_ids=unit_ids, t_start=None, ) self.add_sorting_segment(segment) @@ -1161,6 +1162,7 @@ def __init__( firing_rates: float | np.ndarray, refractory_period_seconds: float | np.ndarray, seed: int, + unit_ids: list[str], t_start: Optional[float] = None, ): self.num_units = num_units @@ -1177,7 +1179,8 @@ def __init__( self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) @@ -1280,7 +1283,7 @@ def __init__( noise_block_size: int = 30000, ): - channel_ids = np.arange(num_channels) + channel_ids = [str(idx) for idx in np.arange(num_channels)] dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 5240edcee7..7a6172369b 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -149,12 +149,12 @@ def divide_segment_into_chunks(num_frames, chunk_size): def divide_recording_into_chunks(recording, chunk_size): - all_chunks = [] + recording_slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - all_chunks.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return all_chunks + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return recording_slices def ensure_n_jobs(recording, n_jobs=1): @@ -185,6 +185,22 @@ def ensure_n_jobs(recording, n_jobs=1): return n_jobs +def chunk_duration_to_chunk_size(chunk_duration, recording): + if isinstance(chunk_duration, float): + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + elif isinstance(chunk_duration, str): + if chunk_duration.endswith("ms"): + chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 + elif chunk_duration.endswith("s"): + chunk_duration = float(chunk_duration.replace("s", "")) + else: + raise ValueError("chunk_duration must ends with s or ms") + chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + else: + raise ValueError("chunk_duration must be str or float") + return chunk_size + + def ensure_chunk_size( recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs ): @@ -231,18 +247,7 @@ def ensure_chunk_size( num_channels = recording.get_num_channels() chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) elif chunk_duration is not None: - if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - elif isinstance(chunk_duration, str): - if chunk_duration.endswith("ms"): - chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 - elif chunk_duration.endswith("s"): - chunk_duration = float(chunk_duration.replace("s", "")) - else: - raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) - else: - raise ValueError("chunk_duration must be str or float") + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment @@ -382,11 +387,13 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self): + def run(self, recording_slices=None): """ Runs the defined jobs. """ - all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size) + + if recording_slices is None: + recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) if self.handle_returns: returns = [] @@ -395,17 +402,17 @@ def run(self): if self.n_jobs == 1: if self.progress_bar: - all_chunks = tqdm(all_chunks, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) worker_ctx = self.init_func(*self.init_args) - for segment_index, frame_start, frame_stop in all_chunks: + for segment_index, frame_start, frame_stop in recording_slices: res = self.func(segment_index, frame_start, frame_stop, worker_ctx) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(all_chunks)) + n_jobs = min(self.n_jobs, len(recording_slices)) # parallel with ProcessPoolExecutor( @@ -414,10 +421,10 @@ def run(self): mp_context=mp.get_context(self.mp_context), initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: - results = executor.map(function_wrapper, all_chunks) + results = executor.map(function_wrapper, recording_slices) if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(all_chunks)) + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) for res in results: if self.handle_returns: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d90a20902d..53c2445c77 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -489,6 +489,7 @@ def run_node_pipeline( names=None, verbose=False, skip_after_n_peaks=None, + recording_slices=None, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -540,6 +541,10 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording. Returns ------- @@ -578,7 +583,7 @@ def run_node_pipeline( **job_kwargs, ) - processor.run() + processor.run(recording_slices=recording_slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 77d427bc88..4aabbfd587 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -18,6 +18,8 @@ fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc, + chunk_duration_to_chunk_size, + split_job_kwargs, ) @@ -512,33 +514,38 @@ def determine_cast_unsigned(recording, dtype): return cast_unsigned -def get_random_data_chunks( +def get_random_recording_slices( recording, - return_scaled=False, + method="full_random", num_chunks_per_segment=20, - chunk_size=10000, - concatenated=True, - seed=0, + chunk_duration="500ms", + chunk_size=None, margin_frames=0, + seed=None, ): """ - Extract random chunks across segments + Get random slice of a recording across segments. - This is used for instance in get_noise_levels() to estimate noise on traces. + This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. Parameters ---------- recording : BaseRecording The recording to get random chunks from - return_scaled : bool, default: False - If True, returned chunks are scaled to uV + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. num_chunks_per_segment : int, default: 20 Number of chunks per segment - chunk_size : int, default: 10000 - Size of a chunk in number of frames + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is ued only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. concatenated : bool, default: True If True chunk are concatenated along time axis - seed : int, default: 0 + seed : int, default: None Random seed margin_frames : int, default: 0 Margin in number of frames to avoid edge effects @@ -547,42 +554,89 @@ def get_random_data_chunks( ------- chunk_list : np.array Array of concatenate chunks per segment + + """ # TODO: if segment have differents length make another sampling that dependant on the length of the segment # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY # And randomize the number of chunk per segment weighted by segment duration - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + else: + raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = recording.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." + ) + rng = np.random.default_rng(seed) + recording_slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = recording.get_num_frames(segment_index) + high = num_frames - chunk_size - margin_frames + random_starts = rng.integers(low=low, high=high, size=size) + random_starts = np.sort(random_starts) + recording_slices += [ + (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts + ] + else: + raise ValueError(f"get_random_recording_slices : wrong method {method}") - rng = np.random.default_rng(seed) - chunk_list = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - random_starts = rng.integers(low=low, high=high, size=size) - segment_trace_chunk = [ - recording.get_traces( - start_frame=start_frame, - end_frame=(start_frame + chunk_size), - segment_index=segment_index, - return_scaled=return_scaled, - ) - for start_frame in random_starts - ] + return recording_slices - chunk_list.extend(segment_trace_chunk) + +def get_random_data_chunks(recording, return_scaled=False, concatenated=True, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_recording_slices()` for more details on parameters. + + + Parameters + ---------- + recording : BaseRecording + The recording to get random chunks from + return_scaled : bool, default: False + If True, returned chunks are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + chunk_list = [] + for segment_index, start_frame, end_frame in recording_slices: + traces_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=return_scaled, + ) + chunk_list.append(traces_chunk) if concatenated: return np.concatenate(chunk_list, axis=0) @@ -637,19 +691,52 @@ def get_closest_channels(recording, channel_ids=None, num_channels=None): return np.array(closest_channels_inds), np.array(dists) +def _noise_level_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + + one_chunk = recording.get_traces( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + return_scaled=worker_ctx["return_scaled"], + ) + + if worker_ctx["method"] == "mad": + med = np.median(one_chunk, axis=0, keepdims=True) + # hard-coded so that core doesn't depend on scipy + noise_levels = np.median(np.abs(one_chunk - med), axis=0) / 0.6744897501960817 + elif worker_ctx["method"] == "std": + noise_levels = np.std(one_chunk, axis=0) + + return noise_levels + + +def _noise_level_chunk_init(recording, return_scaled, method): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["return_scaled"] = return_scaled + worker_ctx["method"] = method + return worker_ctx + + def get_noise_levels( recording: "BaseRecording", return_scaled: bool = True, method: Literal["mad", "std"] = "mad", force_recompute: bool = False, - **random_chunk_kwargs, + random_slices_kwargs: dict = {}, + **kwargs, ) -> np.ndarray: """ Estimate noise for each channel using MAD methods. You can use standard deviation with `method="std"` Internally it samples some chunk across segment. - And then, it use MAD estimator (more robust than STD) + And then, it uses the MAD estimator (more robust than STD) or the STD on each chunk. + Finally the average of all MAD/STD values is performed. + + The result is cached in a property of the recording, so that the next call on the same + recording will use the cached result unless `force_recompute=True`. Parameters ---------- @@ -662,8 +749,11 @@ def get_noise_levels( The method to use to estimate noise levels force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor - random_chunk_kwargs : dict - Kwargs for get_random_data_chunks + random_slices_kwargs : dict + Options transmited to get_random_recording_slices(), please read documentation from this + function for more details. + + {} Returns ------- @@ -679,19 +769,56 @@ def get_noise_levels( if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) else: - random_chunks = get_random_data_chunks(recording, return_scaled=return_scaled, **random_chunk_kwargs) - - if method == "mad": - med = np.median(random_chunks, axis=0, keepdims=True) - # hard-coded so that core doesn't depend on scipy - noise_levels = np.median(np.abs(random_chunks - med), axis=0) / 0.6744897501960817 - elif method == "std": - noise_levels = np.std(random_chunks, axis=0) + # This is to keep backward compatibility + # lets keep for a while and remove this maybe in 0.103.0 + # chunk_size used to be in the signature and now is ambiguous + random_slices_kwargs_, job_kwargs = split_job_kwargs(kwargs) + if len(random_slices_kwargs_) > 0 or "chunk_size" in job_kwargs: + msg = ( + "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" + "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" + "Please read get_random_recording_slices() documentation for more options." + ) + # if the user use both the old and the new behavior then an error is raised + assert len(random_slices_kwargs) == 0, msg + warnings.warn(msg) + random_slices_kwargs = random_slices_kwargs_ + if "chunk_size" in job_kwargs: + random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] + + recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + + noise_levels_chunks = [] + + def append_noise_chunk(res): + noise_levels_chunks.append(res) + + func = _noise_level_chunk + init_func = _noise_level_chunk_init + init_args = (recording, return_scaled, method) + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + job_name="noise_level", + verbose=False, + gather_func=append_noise_chunk, + **job_kwargs, + ) + executor.run(recording_slices=recording_slices) + noise_levels_chunks = np.stack(noise_levels_chunks) + noise_levels = np.mean(noise_levels_chunks, axis=0) + + # set property recording.set_property(key, noise_levels) return noise_levels +get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) + + def get_chunk_with_margin( rec_segment, start_frame, diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 55cbe6070a..fdad87287e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2092,6 +2092,13 @@ def load_data(self): import pandas as pd ext_data = pd.read_csv(ext_data_file, index_col=0) + # we need to cast the index to the unit id dtype (int or str) + unit_ids = self.sorting_analyzer.unit_ids + if ext_data.shape[0] == unit_ids.size: + # we force dtype to be the same as unit_ids + if ext_data.index.dtype != unit_ids.dtype: + ext_data.index = ext_data.index.astype(unit_ids.dtype) + elif ext_data_file.suffix == ".pkl": with ext_data_file.open("rb") as f: ext_data = pickle.load(f) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index b64f0610ea..3e3fcc7384 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -205,6 +205,7 @@ def to_sparse(self, sparsity): unit_ids=self.unit_ids, probe=self.probe, check_for_consistent_sparsity=self.check_for_consistent_sparsity, + is_scaled=self.is_scaled, ) def get_one_template_dense(self, unit_index): diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 934b18ed49..3c8663df70 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -31,7 +31,12 @@ def get_dense_templates_array(one_object: Templates | SortingAnalyzer, return_sc ) ext = one_object.get_extension("templates") if ext is not None: - templates_array = ext.data["average"] + if "average" in ext.data: + templates_array = ext.data.get("average") + elif "median" in ext.data: + templates_array = ext.data.get("median") + else: + raise ValueError("Average or median templates have not been computed.") else: raise ValueError("SortingAnalyzer need extension 'templates' to be computed to retrieve templates") else: diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 626899ab6e..6f5bef3c6c 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -2,6 +2,8 @@ import shutil +from pathlib import Path + from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer from spikeinterface.core import Templates @@ -250,16 +252,17 @@ def test_compute_several(create_cache_folder): if __name__ == "__main__": - - test_ComputeWaveforms(format="memory", sparse=True) - test_ComputeWaveforms(format="memory", sparse=False) - test_ComputeWaveforms(format="binary_folder", sparse=True) - test_ComputeWaveforms(format="binary_folder", sparse=False) - test_ComputeWaveforms(format="zarr", sparse=True) - test_ComputeWaveforms(format="zarr", sparse=False) - test_ComputeRandomSpikes(format="memory", sparse=True) - test_ComputeTemplates(format="memory", sparse=True) - test_ComputeNoiseLevels(format="memory", sparse=False) - - test_get_children_dependencies() - test_delete_on_recompute() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "core" + # test_ComputeWaveforms(format="memory", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="memory", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=True, create_cache_folder=cache_folder) + # test_ComputeWaveforms(format="zarr", sparse=False, create_cache_folder=cache_folder) + # test_ComputeRandomSpikes(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeRandomSpikes(format="binary_folder", sparse=False, create_cache_folder=cache_folder) + test_ComputeTemplates(format="memory", sparse=True, create_cache_folder=cache_folder) + test_ComputeNoiseLevels(format="memory", sparse=False, create_cache_folder=cache_folder) + + # test_get_children_dependencies() + # test_delete_on_recompute(cache_folder) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 64f7f76819..f243dd9d9f 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder): assert snippets.get_num_segments() == len(duration) assert snippets.get_num_channels() == num_channels - assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2]) - assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None)) + assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2]) + assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None)) # annotations / properties snippets.annotate(gre="ta") @@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder): ) # missing property - snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1]) + snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"]) values = snippets.get_property("string_property") assert values[2] == "" @@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder): snippets.set_property, key="string_property_nan", values=["hola", "chabon"], - ids=[0, 1], + ids=["0", "1"], missing_value=np.nan, ) # int properties without missing values raise an error assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2]) - snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200) + snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200) values = snippets.get_property("int_property") assert values.dtype.kind == "i" diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 118b6092a9..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -38,10 +38,12 @@ def test_channelsaggregationrecording(): assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) assert np.allclose( - traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg) + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), ) assert np.allclose( - traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg) + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), ) # all traces traces1 = recording1.get_traces(segment_index=seg) diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index deef2291c6..028eaecf12 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,7 +4,7 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording - +from spikeinterface.core.job_tools import divide_recording_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -191,8 +191,8 @@ def test_run_node_pipeline(cache_folder_creation): unpickled_node = pickle.loads(pickled_node) -def test_skip_after_n_peaks(): - recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) +def test_skip_after_n_peaks_and_recording_slices(): + recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0], seed=2205) # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) @@ -211,18 +211,27 @@ def test_skip_after_n_peaks(): node1 = AmplitudeExtractionNode(recording, parents=[node0], param0=6.6, return_output=True) nodes = [node0, node1] + # skip skip_after_n_peaks = 30 some_amplitudes = run_node_pipeline( recording, nodes, job_kwargs, gather_mode="memory", skip_after_n_peaks=skip_after_n_peaks ) - assert some_amplitudes.size >= skip_after_n_peaks assert some_amplitudes.size < spikes.size + # slices : 1 every 4 + recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = recording_slices[::4] + some_amplitudes = run_node_pipeline( + recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices + ) + tolerance = 1.2 + assert some_amplitudes.size < (spikes.size // 4) * tolerance + # the following is for testing locally with python or ipython. It is not used in ci or with pytest. if __name__ == "__main__": # folder = Path("./cache_folder/core") # test_run_node_pipeline(folder) - test_skip_after_n_peaks() + test_skip_after_n_peaks_and_recording_slices() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 23a1574f2a..dad5273f12 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,6 +11,7 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, + get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -167,6 +168,17 @@ def test_write_memory_recording(): shm.unlink() +def test_get_random_recording_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = get_random_recording_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -182,16 +194,17 @@ def test_get_closest_channels(): def test_get_noise_levels(): + job_kwargs = dict(n_jobs=1, progress_bar=True) rec = generate_recording(num_channels=2, sampling_frequency=1000.0, durations=[60.0]) - noise_levels_1 = get_noise_levels(rec, return_scaled=False) - noise_levels_2 = get_noise_levels(rec, return_scaled=False) + noise_levels_1 = get_noise_levels(rec, return_scaled=False, **job_kwargs) + noise_levels_2 = get_noise_levels(rec, return_scaled=False, **job_kwargs) rec.set_channel_gains(0.1) rec.set_channel_offsets(0) - noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True) + noise_levels = get_noise_levels(rec, return_scaled=True, force_recompute=True, **job_kwargs) - noise_levels = get_noise_levels(rec, return_scaled=True, method="std") + noise_levels = get_noise_levels(rec, return_scaled=True, method="std", **job_kwargs) # Generate a recording following a gaussian distribution to check the result of get_noise. std = 6.0 @@ -201,8 +214,10 @@ def test_get_noise_levels(): recording = NumpyRecording(traces, 30000) assert np.all(noise_levels_1 == noise_levels_2) - assert np.allclose(get_noise_levels(recording, return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) - assert np.allclose(get_noise_levels(recording, method="std", return_scaled=False), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose(get_noise_levels(recording, return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3) + assert np.allclose( + get_noise_levels(recording, method="std", return_scaled=False, **job_kwargs), [std, std], rtol=1e-2, atol=1e-3 + ) def test_get_noise_levels_output(): @@ -216,10 +231,21 @@ def test_get_noise_levels_output(): traces = rng.normal(loc=10.0, scale=std, size=(num_samples, num_channels)) recording = NumpyRecording(traces_list=traces, sampling_frequency=sampling_frequency) - std_estimated_with_mad = get_noise_levels(recording, method="mad", return_scaled=False, chunk_size=1_000) + std_estimated_with_mad = get_noise_levels( + recording, + method="mad", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) + print(std_estimated_with_mad) assert np.allclose(std_estimated_with_mad, [std, std], rtol=1e-2, atol=1e-3) - std_estimated_with_std = get_noise_levels(recording, method="std", return_scaled=False, chunk_size=1_000) + std_estimated_with_std = get_noise_levels( + recording, + method="std", + return_scaled=False, + random_slices_kwargs=dict(num_chunks_per_segment=40, chunk_size=1_000, seed=seed), + ) assert np.allclose(std_estimated_with_std, [std, std], rtol=1e-2, atol=1e-3) @@ -333,14 +359,16 @@ def test_do_recording_attributes_match(): if __name__ == "__main__": # Create a temporary folder using the standard library - import tempfile - - with tempfile.TemporaryDirectory() as tmpdirname: - tmp_path = Path(tmpdirname) - test_write_binary_recording(tmp_path) - test_write_memory_recording() - - test_get_random_data_chunks() - test_get_closest_channels() - test_get_noise_levels() - test_order_channels_by_depth() + # import tempfile + + # with tempfile.TemporaryDirectory() as tmpdirname: + # tmp_path = Path(tmpdirname) + # test_write_binary_recording(tmp_path) + # test_write_memory_recording() + + test_get_random_recording_slices() + # test_get_random_data_chunks() + # test_get_closest_channels() + # test_get_noise_levels() + # test_get_noise_levels_output() + # test_order_channels_by_depth() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 35ab18b5f2..15f089f784 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -31,6 +31,14 @@ def get_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), seed=2205, ) + + # TODO: the tests or the sorting analyzer make assumptions about the ids being integers + # So keeping this the way it was + integer_channel_ids = [int(id) for id in recording.get_channel_ids()] + integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] + + recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a129316ee7..9b7ed11bbb 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -15,7 +15,10 @@ class TestTimeHandling: is generated on the fly. Both time representations are tested here. """ - # Fixtures ##### + # ######################################################################### + # Fixtures + # ######################################################################### + @pytest.fixture(scope="session") def time_vector_recording(self): """ @@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name): raw_recording, times_recording, all_times = time_recording_fixture return (raw_recording, times_recording, all_times) - # Tests ##### + # ######################################################################### + # Tests + # ######################################################################### + def test_has_time_vector(self, time_vector_recording): """ Test the `has_time_vector` function returns `False` before @@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration()) - # Helpers #### + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_time_all_segments(self, request, fixture_name, shift): + """ + Shift the times in every segment using the `None` default, then + check that every segment of the recording is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + times_recording.shift_times(shift) # use default `segment_index=None` + + for idx in range(num_segments): + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift, rtol=0, atol=1e-8 + ) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + @pytest.mark.parametrize("shift", [-123.456, 123.456]) + def test_shift_times_different_segments(self, request, fixture_name, shift): + """ + Shift each segment separately, and check the shifted segment only + is shifted as expected. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + num_segments, orig_seg_data = self._store_all_times(times_recording) + + # For each segment, shift the segment only and check the + # times are updated as expected. + for idx in range(num_segments): + + scaler = idx + 2 + times_recording.shift_times(shift * scaler, segment_index=idx) + + assert np.allclose( + orig_seg_data[idx], times_recording.get_times(segment_index=idx) - shift * scaler, rtol=0, atol=1e-8 + ) + + # Just do a little check that we are not + # accidentally changing some other segments, + # which should remain unchanged at this point in the loop. + if idx != num_segments - 1: + assert np.array_equal(orig_seg_data[idx + 1], times_recording.get_times(segment_index=idx + 1)) + + @pytest.mark.parametrize("fixture_name", ["time_vector_recording", "t_start_recording"]) + def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): + """ + Save the shifted data and check the shift is propagated correctly. + """ + _, times_recording, all_times = self._get_fixture_data(request, fixture_name) + + shift = 100 + times_recording.shift_times(shift=shift) + + times_recording.save(folder=tmp_path / "my_file") + + loaded_recording = si.load_extractor(tmp_path / "my_file") + + for idx in range(times_recording.get_num_segments()): + assert np.array_equal( + times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) + ) + + def _store_all_times(self, recording): + """ + Convenience function to store original times of all segments to a dict. + """ + num_segments = recording.get_num_segments() + seg_data = {} + + for idx in range(num_segments): + seg_data[idx] = copy.deepcopy(recording.get_times(segment_index=idx)) + + return num_segments, seg_data + + # ######################################################################### + # Helpers + # ######################################################################### + def _check_times_match(self, recording, all_times): """ For every segment in a recording, check the `get_times()` diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 1e72b0ab28..3ecb702aa2 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -10,25 +10,29 @@ def test_basic_functions(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2]) - assert np.array_equal(sorting2.unit_ids, [0, 2]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"]) + assert np.array_equal(sorting2.unit_ids, ["0", "2"]) assert sorting2.get_parent() == sorting - sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"]) + sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"]) assert np.array_equal(sorting3.unit_ids, ["a", "b"]) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting2.get_unit_spike_train(unit_id="0", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting3.get_unit_spike_train(unit_id="a", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting2.get_unit_spike_train(unit_id="2", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting3.get_unit_spike_train(unit_id="b", segment_index=0), ) @@ -36,13 +40,13 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_custom_cache_spike_vector(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) assert np.all(cached_spike_vector == computed_spike_vector) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 657b936fb9..975f2fe22f 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge # manual sorting, @@ -15,3 +15,7 @@ from .curation_format import validate_curation_dict, curation_label_to_dataframe, apply_curation from .sortingview_curation import apply_sortingview_curation + +# automated curation +from .model_based_curation import auto_label_units, load_model +from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Tuple import numpy as np import math @@ -12,49 +14,90 @@ HAVE_NUMBA = False from ..core import SortingAnalyzer, Templates -from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph - -_possible_presets = ["similarity_correlograms", "x_contaminations", "temporal_splits", "feature_neighbors"] +_compute_merge_presets = { + "similarity_correlograms": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "quality_score", + ], + "temporal_splits": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "presence_distance", + "quality_score", + ], + "x_contaminations": [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "quality_score", + ], + "feature_neighbors": [ + "num_spikes", + "snr", + "remove_contaminated", + "unit_locations", + "knn", + "quality_score", + ], +} _required_extensions = { - "unit_locations": ["unit_locations"], + "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "template_similarity": ["template_similarity"], - "knn": ["spike_locations", "spike_amplitudes"], + "snr": ["templates", "noise_levels"], + "template_similarity": ["templates", "template_similarity"], + "knn": ["templates", "spike_locations", "spike_amplitudes"], } -def get_potential_auto_merge( +_default_step_params = { + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 150}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.15, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, +} + + +def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + resolve_graph: bool = True, + steps_params: dict = None, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, -) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + force_copy: bool = True, + **job_kwargs, +) -> list[tuple[int | str, int | str]] | Tuple[list[tuple[int | str, int | str]], dict]: """ Algorithm to find and check potential merges between units. @@ -78,6 +121,9 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -98,47 +144,11 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. - resolve_graph : bool, default: False + resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions, if not already computed? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None @@ -146,157 +156,141 @@ def get_potential_auto_merge( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + steps_params : dict + A dictionary whose keys are the steps, and keys are steps parameters. + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite Returns ------- - potential_merges: - A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). - List of pairs that could be merged. + merge_unit_groups: + List of groups that need to be merge. + When `resolve_graph` is true (default) a list of tuples of 2+ elements + If `resolve_graph` is false then a list of tuple of 2 elements is returned instead. outs: Returned only when extra_outputs=True A dictionary that contains data for debugging and plotting. References ---------- - This function is inspired and built upon similar functions from Lussac [Llobet]_, + This function used to be inspired and built upon similar functions from Lussac [Llobet]_, done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + + However, it has been greatly consolidated and refined depending on the presets. """ import scipy sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - # to get fast computation we will not analyse pairs when: - # * not enough spikes for one of theses - # * auto correlogram is contaminated - # * to far away one from each other - - all_steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "correlogram", - "template_similarity", - "presence_distance", - "knn", - "cross_contamination", - "quality_score", - ] - - if preset is not None and preset not in _possible_presets: - raise ValueError(f"preset must be one of {_possible_presets}") - - if steps is None: - if preset is None: - if steps is None: - raise ValueError("You need to specify a preset or steps for the auto-merge function") - elif preset == "similarity_correlograms": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "correlogram", - "quality_score", - ] - elif preset == "temporal_splits": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "presence_distance", - "quality_score", - ] - elif preset == "x_contaminations": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "cross_contamination", - "quality_score", - ] - elif preset == "feature_neighbors": - steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "knn", - "quality_score", - ] - + if preset is None and steps is None: + raise ValueError("You need to specify a preset or steps for the auto-merge function") + elif steps is not None: + # steps has precedence on presets + pass + elif preset is not None: + if preset not in _compute_merge_presets: + raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") + steps = _compute_merge_presets[preset] + + # check at least one extension is needed + at_least_one_extension_to_compute = False for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() n = unit_ids.size - pair_mask = np.triu(np.arange(n)) > 0 + pair_mask = np.triu(np.arange(n), 1) > 0 outs = dict() for step in steps: - assert step in all_steps, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + + # special case for templates + if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + else: + sorting_analyzer.compute(ext, **job_kwargs) + + params = _default_step_params.get(step).copy() + if steps_params is not None and step in steps_params: + params.update(steps_params[step]) # STEP : remove units with too few spikes if step == "num_spikes": + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < params["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=params["refractory_period_ms"], + censored_period_ms=params["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > params["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel - elif step == "unit_locations" in steps: + elif step == "unit_locations": location_ext = sorting_analyzer.get_extension("unit_locations") unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= params["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram - elif step == "correlogram" in steps: + elif step == "correlogram": correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = params["censor_correlograms_ms"] + sigma_smooth_ms = params["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * params["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +300,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -314,22 +308,21 @@ def get_potential_auto_merge( outs["win_sizes"] = win_sizes # STEP : check if potential merge with CC also have template similarity - elif step == "template_similarity" in steps: + elif step == "template_similarity": template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes - elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + elif step == "knn": + pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times - elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + elif step == "presence_distance": + presence_distance_kwargs = params.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -340,40 +333,243 @@ def get_potential_auto_merge( outs["presence_distances"] = presence_distances # STEP : check if the cross contamination is significant - elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + elif step == "cross_contamination": + refractory = ( + params["censored_period_ms"], + params["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, params["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > params["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics - elif step == "quality_score" in steps: + elif step == "quality_score": pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + params["firing_contamination_balance"], + params["refractory_period_ms"], + params["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) - potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) - - # some methods return identities ie (1,1) which we can cleanup first. - potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] + merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) if resolve_graph: - potential_merges = resolve_merging_graph(sorting, potential_merges) + merge_unit_groups = resolve_merging_graph(sorting, merge_unit_groups) if extra_outputs: - return potential_merges, outs + return merge_unit_groups, outs else: - return potential_merges + return merge_unit_groups + + +def auto_merge_units( + sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +) -> SortingAnalyzer: + """ + Compute merge unit groups and apply it on a SortingAnalyzer. + Internally uses `compute_merge_unit_groups()` + """ + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs + ) + + merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) + return merged_analyzer + + +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 1.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + This function is deprecated. Use compute_merge_unit_groups() instead. + This will be removed in 0.103.0 + + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-corrleogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 1.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". + k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. + + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + warnings.warn( + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) + + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return compute_merge_unit_groups( + sorting_analyzer, + preset, + resolve_graph, + steps_params={ + "num_spikes": {"min_spikes": min_spikes}, + "snr": {"min_snr": min_snr}, + "remove_contaminated": { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations": {"max_distance_um": max_distance_um}, + "correlogram": { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance": {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn": {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination": { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score": { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + }, + compute_needed_extensions=True, + extra_outputs=extra_outputs, + steps=steps, + ) def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): @@ -661,10 +857,10 @@ def check_improve_contaminations_score( f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores - k = firing_contamination_balance - score_1 = f_1 * (1 - (k + 1) * c_1) - score_2 = f_2 * (1 - (k + 1) * c_2) - score_new = f_new * (1 - (k + 1) * c_new) + k = 1 + firing_contamination_balance + score_1 = f_1 * (1 - k * c_1) + score_2 = f_2 * (1 - k * c_2) + score_new = f_new * (1 - k * c_new) if score_new < score_1 or score_new < score_2: # the score is not improved diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 5f85538b08..80f251ca43 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict): if not removed_units_set.issubset(unit_set): raise ValueError("Curation format: some removed units are not in the unit list") + for group in curation_dict["merge_unit_groups"]: + if len(group) < 2: + raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") + all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Some units belong to multiple merge groups") + raise ValueError("Curation format: some units belong to multiple merge groups") if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Some units were merged and deleted") + raise ValueError("Curation format: some units were merged and deleted") # Check the labels exclusivity for lbl in curation_dict["manual_labels"]: @@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): if unit_id not in new_unit_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_dict["unit_ids"]).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) @@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): group_values.append(value) if len(set(group_values)) == 1: # all group has the same label or empty - sorting.set_property(key, values=group_values, ids=[new_unit_id]) + sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: for key in label_def["label_options"]: @@ -339,18 +343,22 @@ def apply_curation( elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], - censor_ms=censor_ms, - merging_mode=merging_mode, - sparsity_overlap=sparsity_overlap, - new_id_strategy=new_id_strategy, - return_new_unit_ids=True, - format="memory", - verbose=verbose, - **job_kwargs, - ) + if len(curation_dict["removed_units"]) > 0: + analyzer = analyzer.remove_units(curation_dict["removed_units"]) + if len(curation_dict["merge_unit_groups"]) > 0: + analyzer, new_unit_ids = analyzer.merge_units( + curation_dict["merge_unit_groups"], + censor_ms=censor_ms, + merging_mode=merging_mode, + sparsity_overlap=sparsity_overlap, + new_id_strategy=new_id_strategy, + return_new_unit_ids=True, + format="memory", + verbose=verbose, + **job_kwargs, + ) + else: + new_unit_ids = [] apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) return analyzer else: diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py new file mode 100644 index 0000000000..93ad03734c --- /dev/null +++ b/src/spikeinterface/curation/model_based_curation.py @@ -0,0 +1,435 @@ +import numpy as np +from pathlib import Path +import json +import warnings +import re + +from spikeinterface.core import SortingAnalyzer +from spikeinterface.curation.train_manual_curation import ( + try_to_get_metrics_from_analyzer, + _get_computed_metrics, + _format_metric_dataframe, +) +from copy import deepcopy + + +class ModelBasedClassification: + """ + Class for performing model-based classification on spike sorting data. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + + Attributes + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting data. + pipeline : Pipeline + The pipeline object representing the trained classification model. + required_metrics : Sequence[str] + The list of required metrics for classification, extracted from the pipeline. + + Methods + ------- + predict_labels() + Predicts the labels for the spike sorting data using the trained model. + """ + + def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline): + from sklearn.pipeline import Pipeline + + if not isinstance(pipeline, Pipeline): + raise ValueError("The `pipeline` must be an instance of sklearn.pipeline.Pipeline") + + self.sorting_analyzer = sorting_analyzer + self.pipeline = pipeline + self.required_metrics = pipeline.feature_names_in_ + + def predict_labels( + self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False + ): + """ + Predicts the labels for the spike sorting data using the trained model. + Populates the sorting object with the predicted labels and probabilities as unit properties + + Parameters + ---------- + model_info : dict or None, default: None + Model info, generated with model, used to check metric parameters used to train it. + label_conversion : dict or None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to find in `model_info` file. The dictionary should have the format {old_label: new_label}. + input_data : pandas.DataFrame or None, default: None + The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer. + export_to_phy : bool, default: False. + Whether to export the classified units to Phy format. Default is False. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + Returns + ------- + pd.DataFrame + A dataframe containing the classified units and their corresponding predictions and probabilities, + indexed by their `unit_ids`. + """ + import pandas as pd + + # Get metrics DataFrame for classification + if input_data is None: + input_data = _get_computed_metrics(self.sorting_analyzer) + else: + if not isinstance(input_data, pd.DataFrame): + raise ValueError("Input data must be a pandas DataFrame") + + input_data = self._check_required_metrics_are_present(input_data) + + if model_info is not None: + self._check_params_for_classification(enforce_metric_params, model_info=model_info) + + if model_info is not None and label_conversion is None: + try: + string_label_conversion = model_info["label_conversion"] + # json keys are strings; we convert these to ints + label_conversion = {} + for key, value in string_label_conversion.items(): + label_conversion[int(key)] = value + except: + warnings.warn("Could not find `label_conversion` key in `model_info.json` file") + + input_data = _format_metric_dataframe(input_data) + + # Apply classifier + predictions = self.pipeline.predict(input_data) + probabilities = self.pipeline.predict_proba(input_data) + probabilities = np.max(probabilities, axis=1) + + if isinstance(label_conversion, dict): + + if set(predictions).issubset(set(label_conversion.keys())) is False: + raise ValueError("Labels in predictions do not match those in label_conversion") + predictions = [label_conversion[label] for label in predictions] + + classified_units = pd.DataFrame( + zip(predictions, probabilities), columns=["prediction", "probability"], index=self.sorting_analyzer.unit_ids + ) + + # Set predictions and probability as sorting properties + self.sorting_analyzer.sorting.set_property("classifier_label", predictions) + self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities) + + if export_to_phy: + self._export_to_phy(classified_units) + + return classified_units + + def _check_required_metrics_are_present(self, calculated_metrics): + + # Check all the required metrics have been calculated + required_metrics = set(self.required_metrics) + if required_metrics.issubset(set(calculated_metrics)): + input_data = calculated_metrics[self.required_metrics] + else: + raise ValueError( + "Input data does not contain all required metrics for classification", + f"Missing metrics: {required_metrics.difference(calculated_metrics)}", + ) + + return input_data + + def _check_params_for_classification(self, enforce_metric_params=False, model_info=None): + """ + Check that quality and template metrics parameters match those used to train the model + + Parameters + ---------- + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + model_info : dict, default: None + Dictionary of model info containing provenance of the model. + """ + + extension_names = ["quality_metrics", "template_metrics"] + + metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + model_extension_params = model_info["metric_params"].get(extension_name + "_params") + + if metric_extension is not None and model_extension_params is not None: + + metric_params = metric_extension.params["metric_params"] + + inconsistent_metrics = [] + for metric in model_extension_params["metric_names"]: + model_metric_params = model_extension_params.get("metric_params") + if model_metric_params is None or metric not in model_metric_params: + inconsistent_metrics.append(metric) + else: + if metric_params[metric] != model_metric_params[metric]: + warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + if len(inconsistent_metrics) > 0: + warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + def _export_to_phy(self, classified_units): + """Export the classified units to Phy as cluster_prediction.tsv file""" + + import pandas as pd + + # Create a new DataFrame with unit_id, prediction, and probability columns from dict {unit_id: (prediction, probability)} + classified_df = pd.DataFrame.from_dict(classified_units, orient="index", columns=["prediction", "probability"]) + + # Export to Phy format + try: + sorting_path = self.sorting_analyzer.sorting.get_annotation("phy_folder") + assert sorting_path is not None + assert Path(sorting_path).is_dir() + except AssertionError: + raise ValueError("Phy folder not found in sorting annotations, or is not a directory") + + classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id") + + +def auto_label_units( + sorting_analyzer: SortingAnalyzer, + model_folder=None, + model_name=None, + repo_id=None, + label_conversion=None, + trust_model=False, + trusted=None, + export_to_phy=False, + enforce_metric_params=False, +): + """ + Automatically labels units based on a model-based classification, either from a model + hosted on HuggingFaceHub or one available in a local folder. + + This function returns the predicted labels and the prediction probabilities, and populates + the sorting object with the predicted labels and probabilities in the 'classifier_label' and + 'classifier_probability' properties. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting results. + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + label_conversion : dic | None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to extract from `model_info.json` file. The dictionary should have the format {old_label: new_label}. + export_to_phy : bool, default: False + Whether to export the results to Phy format. Default is False. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + + + Returns + ------- + classified_units : pd.DataFrame + A dataframe containing the classified units, indexed by the `unit_ids`, containing the predicted label + and confidence probability of each labelled unit. + + Raises + ------ + ValueError + If the pipeline is not an instance of sklearn.pipeline.Pipeline. + + """ + from sklearn.pipeline import Pipeline + + model, model_info = load_model( + model_folder=model_folder, repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + if not isinstance(model, Pipeline): + raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline") + + model_based_classification = ModelBasedClassification(sorting_analyzer, model) + + classified_units = model_based_classification.predict_labels( + label_conversion=label_conversion, + export_to_phy=export_to_phy, + model_info=model_info, + enforce_metric_params=enforce_metric_params, + ) + + return classified_units + + +def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a HuggingFaceHub repo or a local folder. + + Parameters + ---------- + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + trust_model : bool, default: False + Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be + automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects. + trusted : list of str, default: None + Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file. + + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + if model_folder is None and repo_id is None: + raise ValueError("Please provide a 'model_folder' or a 'repo_id'.") + elif model_folder is not None and repo_id is not None: + raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.") + elif model_folder is not None: + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + else: + model, model_info = _load_model_from_huggingface( + repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_huggingface(repo_id=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model from a huggingface repo + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + from huggingface_hub import list_repo_files + from huggingface_hub import hf_hub_download + + # get repo filenames + repo_filenames = list_repo_files(repo_id=repo_id) + + # download all skops and json files to temp directory + for filename in repo_filenames: + if Path(filename).suffix in [".skops", ".json"]: + full_path = hf_hub_download(repo_id=repo_id, filename=filename) + model_folder = Path(full_path).parent + + model, model_info = _load_model_from_folder( + model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted + ) + + return model, model_info + + +def _load_model_from_folder(model_folder=None, model_name=None, trust_model=False, trusted=None): + """ + Loads a model and model_info from a folder + + Returns + ------- + model, model_info + A model and metadata about the model + """ + + import skops.io as skio + from skops.io.exceptions import UntrustedTypesFoundException + + folder = Path(model_folder) + assert folder.is_dir(), f"The folder {folder}, does not exist." + + # look for any .skops files + skops_files = list(folder.glob("*.skops")) + assert len(skops_files) > 0, f"There are no '.skops' files in the folder {folder}" + + if len(skops_files) > 1: + if model_name is None: + model_names = [f.name for f in skops_files] + raise ValueError( + f"There are more than 1 '.skops' file in folder {folder}. You have to specify " + f"the file using the 'model_name' argument. Available files:\n{model_names}" + ) + else: + skops_file = folder / Path(model_name) + assert skops_file.is_file(), f"Model file {skops_file} not found." + elif len(skops_files) == 1: + skops_file = skops_files[0] + + if trust_model and trusted is None: + try: + model = skio.load(skops_file) + except UntrustedTypesFoundException as e: + exception_msg = str(e) + # the exception message contains the list of untrusted objects. The following + # search assumes it is the only list in the message. + string_list = re.search(r"\[(.*?)\]", exception_msg).group() + trusted = [list_item for list_item in string_list.split("'") if len(list_item) > 2] + + model = skio.load(skops_file, trusted=trusted) + + model_info_path = folder / "model_info.json" + if not model_info_path.is_file(): + warnings.warn("No 'model_info.json' file found in folder. No metadata can be checked.") + model_info = None + else: + model_info = json.load(open(model_info_path)) + + model_info = handle_backwards_compatibility_metric_params(model_info) + + return model, model_info + + +def handle_backwards_compatibility_metric_params(model_info): + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("quality_metric_params") is not None + ): + if (qm_params := model_info["metric_params"]["quality_metric_params"].get("qm_params")) is not None: + model_info["metric_params"]["quality_metric_params"]["metric_params"] = qm_params + del model_info["metric_params"]["quality_metric_params"]["qm_params"] + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("template_metric_params") is not None + ): + if (tm_params := model_info["metric_params"]["template_metric_params"].get("metrics_kwargs")) is not None: + metric_params = {} + for metric_name in model_info["metric_params"]["template_metric_params"].get("metric_names"): + metric_params[metric_name] = deepcopy(tm_params) + model_info["metric_params"]["template_metric_params"]["metric_params"] = metric_params + del model_info["metric_params"]["template_metric_params"]["metrics_kwargs"] + + return model_info diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 9cd20f4bfc..e9c4c4a463 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..4c05f41a4c 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,16 +3,16 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import get_potential_auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @pytest.mark.parametrize( - "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"] + "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None] ) -def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): +def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting @@ -47,32 +47,38 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): ) if preset is not None: - potential_merges, outs = get_potential_auto_merge( + # do not resolve graph for checking true pairs + merge_unit_groups, outs = compute_merge_unit_groups( sorting_analyzer, preset=preset, - min_spikes=1000, - max_distance_um=150.0, - contamination_thresh=0.2, - corr_diff_thresh=0.16, - template_diff_thresh=0.25, - censored_period_ms=0.0, - refractory_period_ms=4.0, - sigma_smooth_ms=0.6, - adaptative_window_thresh=0.5, - firing_contamination_balance=1.5, + resolve_graph=False, + # min_spikes=1000, + # max_distance_um=150.0, + # contamination_thresh=0.2, + # corr_diff_thresh=0.16, + # template_diff_thresh=0.25, + # censored_period_ms=0.0, + # refractory_period_ms=4.0, + # sigma_smooth_ms=0.6, + # adaptative_window_thresh=0.5, + # firing_contamination_balance=1.5, extra_outputs=True, + **job_kwargs, ) if preset == "x_contaminations": - assert len(potential_merges) == num_unit_splited + assert len(merge_unit_groups) == num_unit_splited for true_pair in other_ids.values(): true_pair = tuple(true_pair) - assert true_pair in potential_merges + assert true_pair in merge_unit_groups else: # when preset is None you have to specify the steps with pytest.raises(ValueError): - potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) - potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, + preset=preset, + steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], + **job_kwargs, ) # DEBUG @@ -93,7 +99,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): # m = correlograms.shape[2] // 2 - # for unit_id1, unit_id2 in potential_merges[:5]: + # for unit_id1, unit_id2 in merge_unit_groups[:5]: # unit_ind1 = sorting_with_split.id_to_index(unit_id1) # unit_ind2 = sorting_with_split.id_to_index(unit_id2) @@ -129,4 +135,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_get_auto_merge_list(sorting_analyzer) + # preset = "x_contaminations" + preset = None + test_compute_merge_unit_groups(sorting_analyzer, preset=preset) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py new file mode 100644 index 0000000000..3683b417df --- /dev/null +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -0,0 +1,167 @@ +import pytest +from pathlib import Path +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.model_based_curation import ModelBasedClassification +from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation.train_manual_curation import _get_computed_metrics + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "curation" +else: + cache_folder = Path("cache_folder") / "curation" + + +@pytest.fixture +def model(): + """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. + It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with + the following labels: [1,0,1,0,1].""" + + model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"]) + return model + + +@pytest.fixture +def required_metrics(): + """These are the metrics which `model` are trained on.""" + return ["num_spikes", "snr", "half_width"] + + +def test_model_based_classification_init(sorting_analyzer_for_curation, model): + """Test that the ModelBasedClassification attributes are correctly initialised""" + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation + assert model_based_classification.pipeline == model[0] + assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_) + + +def test_metric_ordering_independence(sorting_analyzer_for_curation, model): + """The function `auto_label_units` needs the correct metrics to have been computed. However, + it should be independent of the order of computation. We test this here.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + prediction_prob_dataframe_1 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) + + prediction_prob_dataframe_2 = auto_label_units( + sorting_analyzer=sorting_analyzer_for_curation, + model_folder=model_folder, + trusted=["numpy.dtype"], + ) + + assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2) + + +def test_model_based_classification_get_metrics_for_classification( + sorting_analyzer_for_curation, model, required_metrics +): + """If the user has not computed the required metrics, an error should be returned. + This test checks that an error occurs when the required metrics have not been computed, + and that no error is returned when the required metrics have been computed. + """ + + sorting_analyzer_for_curation.delete_extension("quality_metrics") + sorting_analyzer_for_curation.delete_extension("template_metrics") + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + + # Check that ValueError is returned when no metrics are present in sorting_analyzer + with pytest.raises(ValueError): + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + + # Compute some (but not all) of the required metrics in sorting_analyzer, should still error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + with pytest.raises(ValueError): + model_based_classification._check_required_metrics_are_present(computed_metrics) + + # Compute all of the required metrics in sorting_analyzer, no more error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) + + metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) + assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) + assert set(metrics_data.columns.to_list()) == set(required_metrics) + + +def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): + # Test the _export_to_phy() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = {0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)} + # Function should fail here + with pytest.raises(ValueError): + model_based_classification._export_to_phy(classified_units) + # Make temp output folder and set as phy_folder + phy_folder = cache_folder / "phy_folder" + phy_folder.mkdir(parents=True, exist_ok=True) + + model_based_classification.sorting_analyzer.sorting.annotate(phy_folder=phy_folder) + model_based_classification._export_to_phy(classified_units) + assert (phy_folder / "cluster_prediction.tsv").exists() + + +def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model): + """The model `model` has been trained on the `sorting_analyzer` used in this test with + the labels `[1, 0, 1, 0, 1]`. Hence if we apply the model to this `sorting_analyzer` + we expect these labels to be outputted. The test checks this, and also checks + that label conversion works as expected.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + # Test the predict_labels() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = model_based_classification.predict_labels() + predictions = classified_units["prediction"].values + + assert np.all(predictions == np.array([1, 0, 1, 0, 1])) + + conversion = {0: "noise", 1: "good"} + classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion) + predictions_labelled = classified_units_labelled["prediction"] + assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) + + +def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): + """We track whether the metric parameters used to compute the metrics used to train + a model are the same as the parameters used to compute the metrics in the sorting + analyzer which is being curated. If they are different, an error or warning will + be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" + + sorting_analyzer_for_curation.compute( + "quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + ) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + + # an error should be raised if `enforce_metric_params` is True + with pytest.raises(Exception): + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + # but only a warning if `enforce_metric_params` is False + with pytest.warns(UserWarning): + model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) + + # Now test the positive case. Recompute using the default parameters + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 945aca7937..ff80be365d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -49,6 +49,9 @@ def test_gh_curation(): Test curation using GitHub URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) + # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" @@ -76,6 +79,8 @@ def test_sha1_curation(): Test curation using SHA1 URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from SHA1 # curated link: @@ -105,6 +110,8 @@ def test_json_curation(): Test curation using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" @@ -248,6 +255,8 @@ def test_json_no_merge_curation(): Test curation with no merges using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) json_file = parent_folder / "sv-sorting-curation-no-merge.json" sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py new file mode 100644 index 0000000000..f455fbdb9c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -0,0 +1,285 @@ +import pytest +import numpy as np +import tempfile, csv +from pathlib import Path + +from spikeinterface.curation.tests.common import make_sorting_analyzer +from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model + + +@pytest.fixture +def trainer(): + """A simple CurationModelTrainer object is created, which can later by used to + train models using data from `sorting_analyzer`s.""" + + folder = tempfile.mkdtemp() # Create a temporary output folder + imputation_strategies = ["median"] + scaling_techniques = ["standard_scaler"] + classifiers = ["LogisticRegression"] + metric_names = ["metric1", "metric2", "metric3"] + search_kwargs = {"cv": 3} + return CurationModelTrainer( + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + search_kwargs=search_kwargs, + ) + + +def make_temp_training_csv(): + """Create a temporary CSV file with artificially generated quality metrics. + The data is designed to be easy to dicern between units. Even units metric + values are all `0`, while odd units metric values are all `1`. + """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + writer = csv.writer(temp_file) + writer.writerow(["unit_id", "metric1", "metric2", "metric3"]) + for i in range(5): + writer.writerow([i * 2, 0, 0, 0]) + writer.writerow([i * 2 + 1, 1, 1, 1]) + return temp_file.name + + +def test_load_and_preprocess_full(trainer): + """Check that we load and preprocess the csv file from `make_temp_training_csv` + correctly.""" + temp_file_path = make_temp_training_csv() + + # Load and preprocess the data from the temporary CSV file + trainer.load_and_preprocess_csv([temp_file_path]) + + # Assert that the data is loaded and preprocessed correctly + for a, row in trainer.X.iterrows(): + assert np.all(row.values == [float(a % 2)] * 3) + for a, label in enumerate(trainer.y.values): + assert label == a % 2 + for a, row in trainer.testing_metrics.iterrows(): + assert np.all(row.values == [a % 2] * 3) + assert row.name == a + + +def test_apply_scaling_imputation(trainer): + """Take a simple training and test set and check that they are corrected scaled, + using a standard scaler which rescales the training distribution to have mean 0 + and variance 1. Length between each row is 3, so if x0 is the first value in the + column, all other values are scaled as x -> 2/3(x - x0) - 1. The y (labled) values + do not get scaled.""" + + from sklearn.impute._knn import KNNImputer + from sklearn.preprocessing._data import StandardScaler + + imputation_strategy = "knn" + scaling_technique = "standard_scaler" + X_train = np.array([[1, 2, 3], [4, 5, 6]]) + X_test = np.array([[7, 8, 9], [10, 11, 12]]) + y_train = np.array([0, 1]) + y_test = np.array([2, 3]) + + X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation( + imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test + ) + + first_row_elements = X_train[0] + for a, row in enumerate(X_train): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a]) + for a, row in enumerate(X_test): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a]) + + assert np.all(y_train == y_train_scaled) + assert np.all(y_test == y_test_scaled) + + assert isinstance(imputer, KNNImputer) + assert isinstance(scaler, StandardScaler) + + +def test_get_classifier_search_space(trainer): + """For each classifier, there is a hyperparameter space we search over to find its + most accurate incarnation. Here, we check that we do indeed load the approprirate + dict of hyperparameter possibilities""" + + from sklearn.linear_model._logistic import LogisticRegression + + classifier = "LogisticRegression" + model, param_space = trainer.get_classifier_search_space(classifier) + + assert isinstance(model, LogisticRegression) + assert len(param_space) > 0 + assert isinstance(param_space, dict) + + +def test_get_custom_classifier_search_space(): + """Check that if a user passes a custom hyperparameter search space, that this is + passed correctly to the trainer.""" + + classifier = { + "LogisticRegression": { + "C": [0.1, 8.0], + "solver": ["lbfgs"], + "max_iter": [100, 400], + } + } + trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]) + + model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0]) + assert param_space == classifier["LogisticRegression"] + + +def test_saved_files(trainer): + """During the trainer's creation, the following files should be created: + - best_model.skops + - labels.csv + - model_accuracies.csv + - model_info.json + - training_data.csv + This test checks that these exist, and checks some properties of the files.""" + + import pandas as pd + import json + + trainer.X = np.random.rand(10, 3) + trainer.y = np.append(np.ones(5), np.zeros(5)) + + trainer.evaluate_model_config() + trainer_folder = Path(trainer.folder) + + assert trainer_folder.is_dir() + + best_model_path = trainer_folder / "best_model.skops" + model_accuracies_path = trainer_folder / "model_accuracies.csv" + training_data_path = trainer_folder / "training_data.csv" + labels_path = trainer_folder / "labels.csv" + model_info_path = trainer_folder / "model_info.json" + + assert (best_model_path).is_file() + + model_accuracies = pd.read_csv(model_accuracies_path) + model_accuracies["classifier name"].values[0] == "LogisticRegression" + assert len(model_accuracies) == 1 + + training_data = pd.read_csv(training_data_path) + assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10)) + + labels = pd.read_csv(labels_path) + assert np.all(labels.values[:, 1] == trainer.y.astype("float")) + + model_info = pd.read_json(model_info_path) + + with open(model_info_path) as f: + model_info = json.load(f) + + assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"]) + + +def test_train_model(): + """A simple function test to check that `train_model` doesn't fail with one csv inputs""" + + metrics_path = make_temp_training_csv() + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_model_using_two_csvs(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from csv files.""" + + metrics_path_1 = make_temp_training_csv() + metrics_path_2 = make_temp_training_csv() + + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path_1, metrics_path_2], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_using_two_sorting_analyzers(): + """Models can be trained using more than one set of training data. This test checks + that `train_model` works with two inputs, from sorting analzyers. It also checks that + an error is raised if the sorting_analyzers have different sets of metrics computed.""" + + sorting_analyzer_1 = make_sorting_analyzer() + sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}}) + + labels_1 = [0, 1, 1, 1, 1] + labels_2 = [1, 1, 0, 1, 1] + + folder = tempfile.mkdtemp() + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + assert isinstance(trainer, CurationModelTrainer) + + # Check that there is an error raised if the metric names are different + sorting_analyzer_2 = make_sorting_analyzer() + sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}}) + + with pytest.raises(Exception): + trainer = train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + ) + + # Now check that there is an error raised if we demand the same metric params, but don't have them + + sorting_analyzer_2.compute( + { + "quality_metrics": { + "metric_names": ["num_spikes", "snr"], + "metric_params": {"snr": {"peak_mode": "at_index"}}, + } + } + ) + + with pytest.raises(Exception): + train_model( + analyzers=[sorting_analyzer_1, sorting_analyzer_2], + folder=folder, + labels=[labels_1, labels_2], + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + overwrite=True, + enforce_metric_params=True, + ) diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops new file mode 100644 index 0000000000..362405f917 Binary files /dev/null and b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops differ diff --git a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv new file mode 100644 index 0000000000..46680a9e89 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv @@ -0,0 +1,21 @@ +unit_index,0 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 +0,1 +1,0 +2,1 +3,0 +4,1 diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv new file mode 100644 index 0000000000..7f015c380b --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv @@ -0,0 +1,2 @@ +,classifier name,imputation_strategy,scaling_strategy,accuracy,precision,recall,model_id,best_params +0,LogisticRegression,median,StandardScaler(),1.0000,1.0000,1.0000,0,"OrderedDict([('C', 4.811707275233983), ('max_iter', 384), ('solver', 'saga')])" diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json new file mode 100644 index 0000000000..75ced28486 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json @@ -0,0 +1,60 @@ +{ + "metric_params": { + "quality_metric_params": { + "metric_names": [ + "snr", + "num_spikes" + ], + "peak_sign": null, + "seed": null, + "metric_params": { + "num_spikes": {}, + "snr": { + "peak_sign": "neg", + "peak_mode": "extremum" + } + }, + "skip_pc_metrics": false, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "snr", + "num_spikes" + ] + }, + "template_metric_params": { + "metric_names": [ + "half_width" + ], + "sparsity": null, + "peak_sign": "neg", + "upsampling_factor": 10, + "metric_params": { + "half_width": { + "recovery_window_ms": 0.7, + "peak_relative_threshold": 0.2, + "peak_width_ms": 0.1, + "depth_direction": "y", + "min_channels_for_velocity": 5, + "min_r2_velocity": 0.5, + "exp_peak_function": "ptp", + "min_r2_exp_decay": 0.5, + "spread_threshold": 0.2, + "spread_smooth_um": 20, + "column_range": null + } + }, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "half_width" + ] + } + }, + "requirements": { + "spikeinterface": "0.101.1", + "scikit-learn": "1.3.2" + }, + "label_conversion": { + "1": 1, + "0": 0 + } +} diff --git a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv new file mode 100644 index 0000000000..c9efca17ad --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv @@ -0,0 +1,21 @@ +unit_id,snr,num_spikes,half_width +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py new file mode 100644 index 0000000000..7b315b0fba --- /dev/null +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -0,0 +1,843 @@ +import os +import warnings +import numpy as np +import json +import spikeinterface +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.qualitymetrics import ( + get_quality_metric_list, + get_quality_pca_metric_list, + qm_compute_name_to_column_names, +) +from spikeinterface.postprocessing import get_template_metric_names +from spikeinterface.postprocessing.template_metrics import tm_compute_name_to_column_names +from pathlib import Path +from copy import deepcopy + + +def get_default_classifier_search_spaces(): + + from scipy.stats import uniform, randint + + default_classifier_search_spaces = { + "RandomForestClassifier": { + "n_estimators": [100, 150], + "criterion": ["gini", "entropy"], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + "class_weight": ["balanced", "balanced_subsample"], + }, + "AdaBoostClassifier": { + "learning_rate": [1, 2], + "n_estimators": [50, 100], + "algorithm": ["SAMME", "SAMME.R"], + }, + "GradientBoostingClassifier": { + "learning_rate": uniform(0.05, 0.1), + "n_estimators": randint(100, 150), + "max_depth": [2, 4], + "min_samples_split": [2, 4], + "min_samples_leaf": [2, 4], + }, + "SVC": { + "C": uniform(0.001, 10.0), + "kernel": ["sigmoid", "rbf"], + "gamma": uniform(0.001, 10.0), + "probability": [True], + }, + "LogisticRegression": { + "C": uniform(0.001, 10.0), + "solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"], + "max_iter": [100], + }, + "XGBClassifier": { + "max_depth": [2, 4], + "eta": uniform(0.2, 0.5), + "sampling_method": ["uniform"], + "grow_policy": ["depthwise", "lossguide"], + }, + "CatBoostClassifier": {"depth": [2, 4], "learning_rate": uniform(0.05, 0.15), "n_estimators": [100, 150]}, + "LGBMClassifier": {"learning_rate": uniform(0.05, 0.15), "n_estimators": randint(100, 150)}, + "MLPClassifier": { + "activation": ["tanh", "relu"], + "solver": ["adam"], + "alpha": uniform(1e-7, 1e-1), + "learning_rate": ["constant", "adaptive"], + "n_iter_no_change": [32], + }, + } + + return default_classifier_search_spaces + + +class CurationModelTrainer: + """ + Used to train and evaluate machine learning models for spike sorting curation. + + Parameters + ---------- + labels : list of lists, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + folder : str, default: None + The folder where outputs such as models and evaluation metrics will be saved, if specified. Requires the skops library. If None, output will not be saved on file system. + metric_names : list of str, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str or dict, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + seed : int, default: None + Random seed for reproducibility. If None, a random seed will be generated. + smote : bool, default: False + Whether to apply SMOTE for class imbalance. Default is False. Requires imbalanced-learn package. + verbose : bool, default: True + If True, useful information is printed during training. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + + Attributes + ---------- + folder : str + The folder where outputs such as models and evaluation metrics will be saved. Requires the skops library. + labels : list of lists, default: None + List of curated labels for each `sorting_analyzer` and each unit; must be in the same order as the metrics data. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str + The list of classifiers to evaluate. + classifier_search_space : dict or None + Dictionary of classifiers and their hyperparameter search spaces, if provided. If None, default search spaces are used. + seed : int + Random seed for reproducibility. + metrics_list : list of str + The list of metrics to use for training. + X : pandas.DataFrame or None + The feature matrix after preprocessing. + y : pandas.Series or None + The target vector after preprocessing. + testing_metrics : dict or None + Dictionary to hold testing metrics data. + label_conversion : dict or None + Dictionary to map string labels to integer codes if target column contains string labels. + + Methods + ------- + get_default_metrics_list() + Returns the default list of metrics. + load_and_preprocess_full(path) + Loads and preprocesses the data from the given path. + load_data_file(path) + Loads the data file from the given path. + process_test_data_for_classification() + Processes the test data for classification. + apply_scaling_imputation(imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val) + Applies the specified imputation and scaling techniques to the data. + get_classifier_instance(classifier_name) + Returns an instance of the specified classifier. + get_classifier_search_space(classifier_name) + Returns the search space for hyperparameter tuning for the specified classifier. + get_classifier_search_space() + Returns the default search spaces for hyperparameter tuning for the classifiers. + evaluate_model_config(imputation_strategies, scaling_techniques, classifiers) + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. + """ + + def __init__( + self, + labels=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + seed=None, + smote=False, + verbose=True, + search_kwargs=None, + **job_kwargs, + ): + + import pandas as pd + + if imputation_strategies is None: + imputation_strategies = ["median", "most_frequent", "knn", "iterative"] + + if scaling_techniques is None: + scaling_techniques = [ + "standard_scaler", + "min_max_scaler", + "robust_scaler", + ] + + if classifiers is None: + self.classifiers = ["RandomForestClassifier"] + self.classifier_search_space = None + elif isinstance(classifiers, dict): + self.classifiers = list(classifiers.keys()) + self.classifier_search_space = classifiers + elif isinstance(classifiers, list): + self.classifiers = classifiers + self.classifier_search_space = None + else: + raise ValueError("classifiers must be a list or dictionary") + + # check if labels is a list of lists + if not all(isinstance(label, list) or isinstance(label, np.ndarray) for label in labels): + raise ValueError("labels must be a list of lists") + + self.folder = Path(folder) if folder is not None else None + self.imputation_strategies = imputation_strategies + self.scaling_techniques = scaling_techniques + self.test_size = test_size + self.seed = seed if seed is not None else np.random.default_rng(seed=None).integers(0, 2**31) + self.metrics_params = {} + self.smote = smote + self.label_conversion = None + self.verbose = verbose + self.search_kwargs = search_kwargs + + self.X = None + self.testing_metrics = None + + self.requirements = {"spikeinterface": spikeinterface.__version__} + + self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels]) + + self.metric_names = metric_names + + if self.folder is not None and not self.folder.is_dir(): + self.folder.mkdir(parents=True, exist_ok=True) + + # update job_kwargs with global ones + job_kwargs = fix_job_kwargs(job_kwargs) + self.n_jobs = job_kwargs["n_jobs"] + + def get_default_metrics_list(self): + """Returns the default list of metrics.""" + return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_names() + + def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): + """ + Loads and preprocesses the quality metrics and labels from the given list of SortingAnalyzer objects. + """ + import pandas as pd + + metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers] + check_metric_names_are_the_same(metrics_for_each_analyzer) + + self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) + + # Set metric names to those calculated if not provided + if self.metric_names is None: + warnings.warn("No metric_names provided, using all metrics calculated by the analyzers") + self.metric_names = self.testing_metrics.columns.tolist() + + conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params) + + self.metrics_params = {} + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + self.metrics_params[extension_name + "_params"] = metric_extension.params + + # Only save metric params which are 1) consistent and 2) exist in metric_names + metric_names = metric_extension.params["metric_names"] + consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics))) + consistent_metric_params = { + metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics + } + self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params + + self.process_test_data_for_classification() + + def _check_metrics_parameters(self, analyzers, enforce_metric_params): + """Checks that the metrics of each analyzer have been calcualted using the same parameters""" + + extension_names = ["quality_metrics", "template_metrics"] + + conflicting_metrics = [] + for analyzer_index_1, analyzer_1 in enumerate(analyzers): + for analyzer_index_2, analyzer_2 in enumerate(analyzers): + + if analyzer_index_1 <= analyzer_index_2: + continue + else: + + metric_params_1 = {} + metric_params_2 = {} + + for extension_name in extension_names: + if (extension_1 := analyzer_1.get_extension(extension_name)) is not None: + metric_params_1.update(extension_1.params["metric_params"]) + if (extension_2 := analyzer_2.get_extension(extension_name)) is not None: + metric_params_2.update(extension_2.params["metric_params"]) + + conflicting_metrics_between_1_2 = [] + # check quality metrics params + for metric, params_1 in metric_params_1.items(): + if params_1 != metric_params_2.get(metric): + conflicting_metrics_between_1_2.append(metric) + + conflicting_metrics += conflicting_metrics_between_1_2 + + if len(conflicting_metrics_between_1_2) > 0: + warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}" + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + unique_conflicting_metrics = set(conflicting_metrics) + return unique_conflicting_metrics + + def load_and_preprocess_csv(self, paths): + self._load_data_files(paths) + self.process_test_data_for_classification() + self.get_metric_params_csv() + + def get_metric_params_csv(self): + + from itertools import chain + + qm_metric_names = list(chain.from_iterable(qm_compute_name_to_column_names.values())) + tm_metric_names = list(chain.from_iterable(tm_compute_name_to_column_names.values())) + + quality_metric_names = [] + template_metric_names = [] + + for metric_name in self.metric_names: + if metric_name in qm_metric_names: + quality_metric_names.append(metric_name) + if metric_name in tm_metric_names: + template_metric_names.append(metric_name) + + self.metrics_params = {} + if quality_metric_names != {}: + self.metrics_params["quality_metric_params"] = {"metric_names": quality_metric_names} + if template_metric_names != {}: + self.metrics_params["template_metric_params"] = {"metric_names": template_metric_names} + + return + + def process_test_data_for_classification(self): + """ + Cleans the input data so that it can be used by sklearn. + + Extracts the target variable and features from the loaded dataset. + It handles string labels by converting them to integer codes and reindexes the + feature matrix to match the specified metrics list. Infinite values in the features + are replaced with NaN, and any remaining NaN values are filled with zeros. + + Raises + ------ + ValueError + If the target column specified is not found in the loaded dataset. + + Notes + ----- + If the target column contains string labels, a warning is issued and the labels + are converted to integer codes. The mapping from string labels to integer codes + is stored in the `label_conversion` attribute. + """ + + # Convert string labels to integer codes to allow classification + new_y = self.y.astype("category").cat.codes + self.label_conversion = dict(zip(new_y, self.y)) + self.y = new_y + + # Extract features + try: + if (set(self.metric_names) - set(self.testing_metrics.columns) != set()) and self.verbose is True: + print( + f"Dropped metrics (calculated but not included in metric_names): {set(self.testing_metrics.columns) - set(self.metric_names)}" + ) + self.X = self.testing_metrics[self.metric_names] + except KeyError as e: + raise KeyError(f"{str(e)}, metrics_list contains invalid metric names") + + self.X = self.testing_metrics.reindex(columns=self.metric_names) + self.X = _format_metric_dataframe(self.X) + + def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test): + """Impute and scale the data using the specified techniques.""" + from sklearn.experimental import enable_iterative_imputer + from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer + from sklearn.ensemble import HistGradientBoostingRegressor + from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler + + if imputation_strategy == "knn": + imputer = KNNImputer(n_neighbors=5) + elif imputation_strategy == "iterative": + imputer = IterativeImputer( + estimator=HistGradientBoostingRegressor(random_state=self.seed), random_state=self.seed + ) + else: + imputer = SimpleImputer(strategy=imputation_strategy) + + if scaling_technique == "standard_scaler": + scaler = StandardScaler() + elif scaling_technique == "min_max_scaler": + scaler = MinMaxScaler() + elif scaling_technique == "robust_scaler": + scaler = RobustScaler() + else: + raise ValueError( + f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler." + ) + + y_train_processed = y_train.astype(int) + y_test = y_test.astype(int) + + X_train_imputed = imputer.fit_transform(X_train) + X_test_imputed = imputer.transform(X_test) + X_train_processed = scaler.fit_transform(X_train_imputed) + X_test_processed = scaler.transform(X_test_imputed) + + # Apply SMOTE for class imbalance + if self.smote: + try: + from imblearn.over_sampling import SMOTE + except ModuleNotFoundError: + raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE") + smote = SMOTE(random_state=self.seed) + X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed) + + return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler + + def get_classifier_instance(self, classifier_name): + from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier + from sklearn.svm import SVC + from sklearn.linear_model import LogisticRegression + from sklearn.neural_network import MLPClassifier + + classifier_mapping = { + "RandomForestClassifier": RandomForestClassifier(random_state=self.seed), + "AdaBoostClassifier": AdaBoostClassifier(random_state=self.seed), + "GradientBoostingClassifier": GradientBoostingClassifier(random_state=self.seed), + "SVC": SVC(random_state=self.seed), + "LogisticRegression": LogisticRegression(random_state=self.seed), + "MLPClassifier": MLPClassifier(random_state=self.seed), + } + + # Check lightgbm package install + if classifier_name == "LGBMClassifier": + try: + import lightgbm + + self.requirements["lightgbm"] = lightgbm.__version__ + classifier_mapping["LGBMClassifier"] = lightgbm.LGBMClassifier(random_state=self.seed, verbose=-1) + except ImportError: + raise ImportError("Please install lightgbm package to use LGBMClassifier") + elif classifier_name == "CatBoostClassifier": + try: + import catboost + + self.requirements["catboost"] = catboost.__version__ + classifier_mapping["CatBoostClassifier"] = catboost.CatBoostClassifier( + silent=True, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install catboost package to use CatBoostClassifier") + elif classifier_name == "XGBClassifier": + try: + import xgboost + + self.requirements["xgboost"] = xgboost.__version__ + classifier_mapping["XGBClassifier"] = xgboost.XGBClassifier( + use_label_encoder=False, random_state=self.seed + ) + except ImportError: + raise ImportError("Please install xgboost package to use XGBClassifier") + + if classifier_name not in classifier_mapping: + raise ValueError( + f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + return classifier_mapping[classifier_name] + + def get_classifier_search_space(self, classifier_name): + + default_classifier_search_spaces = get_default_classifier_search_spaces() + + if classifier_name not in default_classifier_search_spaces: + raise ValueError( + f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())" + ) + + model = self.get_classifier_instance(classifier_name) + if self.classifier_search_space is not None: + param_space = self.classifier_search_space[classifier_name] + else: + param_space = default_classifier_search_spaces[classifier_name] + return model, param_space + + def evaluate_model_config(self): + """ + Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers. + + This method splits the preprocessed data into training and testing sets, then evaluates the specified + combinations of imputation strategies, scaling techniques, and classifiers. The evaluation results are + saved to the output folder. + + Raises + ------ + ValueError + If any of the specified classifier names are not recognized. + + Notes + ----- + The method converts the classifier names to actual classifier instances before evaluating them. + The evaluation results, including the best model and its parameters, are saved to the output folder. + """ + from sklearn.model_selection import train_test_split + + X_train, X_test, y_train, y_test = train_test_split( + self.X, self.y, test_size=self.test_size, random_state=self.seed, stratify=self.y + ) + classifier_instances = [self.get_classifier_instance(clf) for clf in self.classifiers] + self._evaluate( + self.imputation_strategies, + self.scaling_techniques, + classifier_instances, + X_train, + X_test, + y_train, + y_test, + self.search_kwargs, + ) + + def _load_data_files(self, paths): + import pandas as pd + + self.testing_metrics = pd.concat([pd.read_csv(path, index_col=0) for path in paths], axis=0) + + def _evaluate( + self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test, search_kwargs + ): + from joblib import Parallel, delayed + from sklearn.pipeline import Pipeline + import pandas as pd + + results = Parallel(n_jobs=self.n_jobs)( + delayed(self._train_and_evaluate)( + imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, idx, search_kwargs + ) + for idx, (imputation_strategy, scaler, classifier) in enumerate( + (imputation_strategy, scaler, classifier) + for imputation_strategy in imputation_strategies + for scaler in scaling_techniques + for classifier in classifiers + ) + ) + + test_accuracies, models = zip(*results) + + if self.search_kwargs is None or self.search_kwargs.get("scoring"): + scoring_method = "balanced_accuracy" + else: + scoring_method = self.search_kwargs.get("scoring") + + self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False) + + best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"]) + best_model, best_imputer, best_scaler = models[best_model_id] + + best_pipeline = Pipeline( + [("imputer", best_imputer), ("scaler", best_scaler), ("classifier", best_model.best_estimator_)] + ) + + self.best_pipeline = best_pipeline + + if self.folder is not None: + self._save() + + def _save(self): + from skops.io import dump + import sklearn + import pandas as pd + + # export training data and labels + pd.DataFrame(self.X).to_csv(self.folder / f"training_data.csv", index_label="unit_id") + pd.DataFrame(self.y).to_csv(self.folder / f"labels.csv", index_label="unit_index") + + self.requirements["scikit-learn"] = sklearn.__version__ + + # Dump to skops if folder is provided + dump(self.best_pipeline, self.folder / f"best_model.skops") + self.test_accuracies_df.to_csv(self.folder / f"model_accuracies.csv", float_format="%.4f") + + model_info = {} + model_info["metric_params"] = self.metrics_params + + model_info["requirements"] = self.requirements + + model_info["label_conversion"] = self.label_conversion + + param_file = self.folder / "model_info.json" + Path(param_file).write_text(json.dumps(model_info, indent=4), encoding="utf8") + + def _train_and_evaluate( + self, imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, model_id, search_kwargs + ): + from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score + + search_kwargs = set_default_search_kwargs(search_kwargs) + + X_train_scaled, X_test_scaled, y_train, y_test, imputer, scaler = self.apply_scaling_imputation( + imputation_strategy, scaler, X_train, X_test, y_train, y_test + ) + if self.verbose is True: + print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}") + model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) + + try: + from skopt import BayesSearchCV + + model = BayesSearchCV( + model, + param_space, + random_state=self.seed, + **search_kwargs, + ) + except: + if self.verbose is True: + print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV") + from sklearn.model_selection import RandomizedSearchCV + + model = RandomizedSearchCV(model, param_space, **search_kwargs) + + model.fit(X_train_scaled, y_train) + y_pred = model.predict(X_test_scaled) + balanced_acc = balanced_accuracy_score(y_test, y_pred) + precision = precision_score(y_test, y_pred, average="macro") + recall = recall_score(y_test, y_pred, average="macro") + return { + "classifier name": classifier.__class__.__name__, + "imputation_strategy": imputation_strategy, + "scaling_strategy": scaler, + "balanced_accuracy": balanced_acc, + "precision": precision, + "recall": recall, + "model_id": model_id, + "best_params": model.best_params_, + }, (model, imputer, scaler) + + +def train_model( + mode="analyzers", + labels=None, + analyzers=None, + metrics_paths=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + overwrite=False, + seed=None, + search_kwargs=None, + verbose=True, + enforce_metric_params=False, + **job_kwargs, +): + """ + Trains and evaluates machine learning models for spike sorting curation. + + This function initializes a `CurationModelTrainer` object, loads and preprocesses the data, + and evaluates the specified combinations of imputation strategies, scaling techniques, and classifiers. + The evaluation results, including the best model and its parameters, are saved to the output folder. + + Parameters + ---------- + mode : "analyzers" | "csv", default: "analyzers" + Mode to use for training. + analyzers : list of SortingAnalyzer | None, default: None + List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode. + labels : list of list | None, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + metrics_paths : list of str or None, default: None + List of paths to the CSV files containing the metrics data if using 'csv' mode. + folder : str | None, default: None + The folder where outputs such as models and evaluation metrics will be saved. + metric_names : list of str | None, default: None + A list of metrics to use for training. If None, default metrics will be used. + imputation_strategies : list of str | None, default: None + A list of imputation strategies to try. Can be "knn”, "iterative" or any allowed + strategy passable to the sklearn `SimpleImputer`. If None, the default strategies + `["median", "most_frequent", "knn", "iterative"]` will be used. + scaling_techniques : list of str | None, default: None + A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler", + or "robust_scaler", If None, all techniques will be used. + classifiers : list of str | dict | None, default: None + A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces. + test_size : float, default: 0.2 + Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklear`. + overwrite : bool, default: False + Overwrites the `folder` if it already exists + seed : int | None, default: None + Random seed for reproducibility. If None, a random seed will be generated. + search_kwargs : dict or None, default: None + Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use + `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`. + verbose : bool, default: True + If True, useful information is printed during training. + enforce_metric_params : bool, default: False + If True and metric parameters used to calculate metrics for different `sorting_analyzer`s are + different, an error will be raised. + + + Returns + ------- + CurationModelTrainer + The `CurationModelTrainer` object used for training and evaluation. + + Notes + ----- + This function handles the entire workflow of initializing the trainer, loading and preprocessing the data, + and evaluating the models. The evaluation results are saved to the specified output folder. + """ + + if folder is None: + raise Exception("You must supply a folder for the model to be saved in using `folder='path/to/folder/'`") + + if overwrite is False: + assert not Path(folder).is_dir(), f"folder {folder} already exists, choose another name or use overwrite=True" + + if labels is None: + raise Exception("You must supply a list of lists of curated labels using `labels = [[...],[...],...]`") + + if mode not in ["analyzers", "csv"]: + raise Exception("`mode` must be equal to 'analyzers' or 'csv'.") + + if (test_size > 1.0) or (0.0 > test_size): + raise Exception("`test_size` must be between 0.0 and 1.0") + + trainer = CurationModelTrainer( + labels=labels, + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + test_size=test_size, + seed=seed, + verbose=verbose, + search_kwargs=search_kwargs, + **job_kwargs, + ) + + if mode == "analyzers": + assert analyzers is not None, "Analyzers must be provided as a list for mode 'analyzers'" + trainer.load_and_preprocess_analyzers(analyzers, enforce_metric_params) + + elif mode == "csv": + for metrics_path in metrics_paths: + assert Path(metrics_path).is_file(), f"{metrics_path} is not a file." + trainer.load_and_preprocess_csv(metrics_paths) + + trainer.evaluate_model_config() + return trainer + + +def _get_computed_metrics(sorting_analyzer): + """Loads and organises the computed metrics from a sorting_analyzer into a single dataframe""" + + import pandas as pd + + quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer) + calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1) + + # Remove any metrics for non-existent units, raise error if no units are present + calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())] + if calculated_metrics.shape[0] == 0: + raise ValueError("No units present in sorting data") + + return calculated_metrics + + +def try_to_get_metrics_from_analyzer(sorting_analyzer): + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + if any(metric_extensions) is False: + raise ValueError( + "At least one of quality metrics or template metrics must be computed before classification.", + "Compute both using `sorting_analyzer.compute('quality_metrics', 'template_metrics')", + ) + + metric_extensions_data = [] + for metric_extension in metric_extensions: + try: + metric_extensions_data.append(metric_extension.get_data()) + except: + metric_extensions_data.append(None) + + return metric_extensions_data + + +def set_default_search_kwargs(search_kwargs): + + if search_kwargs is None: + search_kwargs = {} + + if search_kwargs.get("cv") is None: + search_kwargs["cv"] = 5 + if search_kwargs.get("scoring") is None: + search_kwargs["scoring"] = "balanced_accuracy" + if search_kwargs.get("n_iter") is None: + search_kwargs["n_iter"] = 25 + + return search_kwargs + + +def check_metric_names_are_the_same(metrics_for_each_analyzer): + """ + Given a list of dataframes, checks that the keys are all equal. + """ + + for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer): + for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer): + if i > j: + metric_names_1 = set(metrics_for_analyzer_1.keys()) + metric_names_2 = set(metrics_for_analyzer_2.keys()) + if metric_names_1 != metric_names_2: + metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2) + metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1) + + error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n" + if metrics_in_1_but_not_2: + error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does." + if metrics_in_2_but_not_1: + error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does." + raise Exception(error_message) + + +def _format_metric_dataframe(input_data): + + input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x) + input_data = input_data.astype("float32") + + return input_data diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 3a4be9213a..ab08401382 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -20,10 +20,10 @@ def export_report( **job_kwargs, ): """ - Exports a SI spike sorting report. The report includes summary figures of the spike sorting output - (e.g. amplitude distributions, unit localization and depth VS amplitude) as well as unit-specific reports, - that include waveforms, templates, template maps, ISI distributions, and more. - + Exports a SI spike sorting report. The report includes summary figures of the spike sorting output. + What is plotted depends on what has been calculated. Unit locations and unit waveforms are always included. + Unit waveform densities, correlograms and spike amplitudes are plotted if `waveforms`, `correlograms`, + and `spike_amplitudes` have been computed for the given `sorting_analyzer`. Parameters ---------- diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index d7e5b58e11..728d352973 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +import warnings import numpy as np import probeinterface @@ -30,8 +31,10 @@ class CompressedBinaryIblExtractor(BaseRecording): stream_name : {"ap", "lp"}, default: "ap". Whether to load AP or LFP band, one of "ap" or "lp". - cbin_file : str or None, default None + cbin_file_path : str, Path or None, default None The cbin file of the recording. If None, searches in `folder_path` for file. + cbin_file : str or None, default None + (deprecated) The cbin file of the recording. If None, searches in `folder_path` for file. Returns ------- @@ -41,14 +44,23 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file=None): + def __init__( + self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None + ): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is None: + if cbin_file is not None: + warnings.warn( + "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + DeprecationWarning, + stacklevel=2, + ) + cbin_file_path = cbin_file + if cbin_file_path is None: folder_path = Path(folder_path) # check bands assert stream_name in ["ap", "lp"], "stream_name must be one of: 'ap', 'lp'" @@ -60,17 +72,17 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", assert ( len(curr_cbin_files) == 1 ), f"There should only be one `*.cbin` file in the folder, but {print(curr_cbin_files)} have been found" - cbin_file = curr_cbin_files[0] + cbin_file_path = curr_cbin_files[0] else: - cbin_file = Path(cbin_file) - folder_path = cbin_file.parent + cbin_file_path = Path(cbin_file_path) + folder_path = cbin_file_path.parent - ch_file = cbin_file.with_suffix(".ch") - meta_file = cbin_file.with_suffix(".meta") + ch_file = cbin_file_path.with_suffix(".ch") + meta_file = cbin_file_path.with_suffix(".meta") # reader cbuffer = mtscomp.Reader() - cbuffer.open(cbin_file, ch_file) + cbuffer.open(cbin_file_path, ch_file) # meta data meta = read_meta_file(meta_file) @@ -119,7 +131,7 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", self._kwargs = { "folder_path": str(Path(folder_path).resolve()), "load_sync_channel": load_sync_channel, - "cbin_file": str(Path(cbin_file).resolve()), + "cbin_file_path": str(Path(cbin_file_path).resolve()), } diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index f055e1d7c9..d2886d9e79 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -72,6 +72,7 @@ def write_recording( params_fname="params.json", geom_fname="geom.csv", dtype=None, + verbose=False, **job_kwargs, ): """Write a recording to file in MDA format. @@ -93,6 +94,8 @@ def write_recording( File name of geom file dtype : dtype or None, default: None Data type to be used. If None dtype is same as recording traces. + verbose : bool + If True, shows progress bar when saving recording. **job_kwargs: Use by job_tools modules to set: @@ -130,6 +133,7 @@ def write_recording( dtype=dtype, byte_offset=header_size, add_file_extension=False, + verbose=verbose, **job_kwargs, ) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 24bc7591e4..dd24e6cae7 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -250,6 +250,7 @@ def __init__( except: warnings.warn(f"Could not load synchronized timestamps for {stream_name}") + self.annotate(experiment_name=f"experiment{exp_id}") self._stream_folders = stream_folders self._kwargs.update( diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 2f360ed864..e0604f7496 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -28,6 +28,10 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): ids: ["source3.1" , "source3.2", "source3.3", "source3.4"] all_annotations : bool, default: False Load exhaustively all annotations from neo. + reading_attempts : int, default: 25 + Number of attempts to read the file before raising an error + This opening process is somewhat unreliable and might fail occasionally. Adjust this higher + if you encounter problems in opening the file. Examples -------- @@ -37,8 +41,16 @@ class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): NeoRawIOClass = "Plexon2RawIO" - def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids=True, all_annotations=False): - neo_kwargs = self.map_to_neo_kwargs(file_path) + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + use_names_as_ids=True, + all_annotations=False, + reading_attempts: int = 25, + ): + neo_kwargs = self.map_to_neo_kwargs(file_path, reading_attempts=reading_attempts) NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, @@ -50,8 +62,18 @@ def __init__(self, file_path, stream_id=None, stream_name=None, use_names_as_ids self._kwargs.update({"file_path": str(file_path)}) @classmethod - def map_to_neo_kwargs(cls, file_path): + def map_to_neo_kwargs(cls, file_path, reading_attempts: int = 25): + neo_kwargs = {"filename": str(file_path)} + + from packaging.version import Version + import neo + + neo_version = Version(neo.__version__) + + if neo_version > Version("0.13.3"): + neo_kwargs["reading_attempts"] = reading_attempts + return neo_kwargs diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d797e64910..171992f6b1 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -599,6 +599,8 @@ def __init__( else: gains, offsets, locations, groups = self._fetch_main_properties_backend() self.extra_requirements.append("h5py") + if stream_mode is not None: + self.extra_requirements.append(stream_mode) self.set_channel_gains(gains) self.set_channel_offsets(offsets) if locations is not None: @@ -1100,6 +1102,8 @@ def __init__( for property_name, property_values in properties.items(): values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] self.set_property(property_name, values) + if stream_mode is not None: + self.extra_requirements.append(stream_mode) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 0ef6697c6c..78e6afb65e 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder): cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) + ids_as_integers = [id for id in range(rec.get_num_channels())] + rec = rec.rename_channels(new_channel_ids=ids_as_integers) + + ids_as_integers = [id for id in range(sort.get_num_units())] + sort = sort.rename_units(new_unit_ids=ids_as_integers) + MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") rec_mda = MdaRecordingExtractor(cache_folder / "mdatest") probe = rec_mda.get_probe() diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index fcdd766f4f..3da92331a6 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -368,7 +368,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] entities = [ - ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ("plexon/4chDemoPL2.pl2", {"stream_name": "WB-Wideband"}), ] diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index e6278fc59f..59c12a9923 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -3,6 +3,8 @@ import warnings import numpy as np +from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel try: import numba @@ -12,10 +14,6 @@ HAVE_NUMBA = False -from spikeinterface.core import compute_sparsity, SortingAnalyzer, Templates -from spikeinterface.core.template_tools import get_template_extremum_channel, _get_nbefore, get_dense_templates_array - - def compute_monopolar_triangulation( sorting_analyzer_or_templates: SortingAnalyzer | Templates, unit_ids=None, @@ -77,7 +75,11 @@ def compute_monopolar_triangulation( contact_locations = sorting_analyzer_or_templates.get_channel_locations() - sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -106,7 +108,7 @@ def compute_monopolar_triangulation( # wf is (nsample, nchan) - chann is only nieghboor wf = templates[i, :, :][:, chan_inds] if feature == "ptp": - wf_data = wf.ptp(axis=0) + wf_data = np.ptp(wf, axis=0) elif feature == "energy": wf_data = np.linalg.norm(wf, axis=0) elif feature == "peak_voltage": @@ -157,9 +159,13 @@ def compute_center_of_mass( assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature" - sparsity = compute_sparsity( - sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um - ) + if sorting_analyzer_or_templates.sparsity is None: + sparsity = compute_sparsity( + sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um + ) + else: + sparsity = sorting_analyzer_or_templates.sparsity + templates = get_dense_templates_array( sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates) ) @@ -180,7 +186,7 @@ def compute_center_of_mass( wf = templates[i, :, :] if feature == "ptp": - wf_data = (wf[:, chan_inds]).ptp(axis=0) + wf_data = np.ptp(wf[:, chan_inds], axis=0) elif feature == "mean": wf_data = (wf[:, chan_inds]).mean(axis=0) elif feature == "energy": @@ -650,8 +656,55 @@ def get_convolution_weights( enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True) +def compute_location_max_channel( + templates_or_sorting_analyzer: SortingAnalyzer | Templates, + unit_ids=None, + peak_sign: "neg" | "pos" | "both" = "neg", + mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", +) -> np.ndarray: + """ + Localize a unit using max channel. + + This uses internally `get_template_extremum_channel()` + + + Parameters + ---------- + templates_or_sorting_analyzer : SortingAnalyzer | Templates + A SortingAnalyzer or Templates object + unit_ids: list[str] | list[int] | None + A list of unit_id to restrict the computation + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + + Returns + ------- + unit_locations: np.ndarray + 2d + """ + extremum_channels_index = get_template_extremum_channel( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" + ) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() + if unit_ids is None: + unit_ids = templates_or_sorting_analyzer.unit_ids + else: + unit_ids = np.asarray(unit_ids) + unit_locations = np.zeros((unit_ids.size, 2), dtype="float32") + for i, unit_id in enumerate(unit_ids): + unit_locations[i, :] = contact_locations[extremum_channels_index[unit_id]] + + return unit_locations + + _unit_location_methods = { "center_of_mass": compute_center_of_mass, "grid_convolution": compute_grid_convolution, "monopolar_triangulation": compute_monopolar_triangulation, + "max_channel": compute_location_max_channel, } diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 306e9594b8..1969480503 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. - metrics_kwargs : dict - Additional arguments to pass to the metric functions. Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict of dicts or None, default: None + Dictionary with parameters for template metrics calculation. + Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` Returns ------- @@ -97,18 +84,32 @@ class ComputeTemplateMetrics(AnalyzerExtension): extension_name = "template_metrics" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, + metric_params=None, metrics_kwargs=None, include_multi_channel_metrics=False, delete_existing_metrics=False, @@ -134,33 +135,24 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - if metrics_kwargs is None: - metrics_kwargs_ = _default_function_kwargs.copy() - if len(other_kwargs) > 0: - for m in other_kwargs: - if m in metrics_kwargs_: - metrics_kwargs_[m] = other_kwargs[m] - else: - metrics_kwargs_ = _default_function_kwargs.copy() - metrics_kwargs_.update(metrics_kwargs) + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(metrics_kwargs) + + metric_params_ = get_default_tm_params(metric_names) + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metrics_kwargs"] - # checks that existing metrics were calculated using the same params - if existing_params != metrics_kwargs_: - warnings.warn( - f"The parameters used to calculate the previous template metrics are different" - f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." - ) - tm_extension.params["metric_names"] = [] - existing_metric_names = [] - else: - existing_metric_names = tm_extension.params["metric_names"] - + existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] @@ -171,7 +163,7 @@ def _set_params( sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs_, + metric_params=metric_params_, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, ) @@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.params["delete_existing_metrics"] metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( @@ -343,9 +335,21 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): - computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + # some metrics names produce data columns with other names. This deals with that. + for column_name in tm_compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = tm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics @@ -372,6 +376,35 @@ def _get_data(self): ) +def get_default_tm_params(metric_names): + if metric_names is None: + metric_names = get_template_metric_names() + + base_tm_params = _default_function_kwargs + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(base_tm_params) + + return metric_params + + +# a dict converting the name of the metric for computation to the output of that computation +tm_compute_name_to_column_names = { + "peak_to_valley": ["peak_to_valley"], + "peak_trough_ratio": ["peak_trough_ratio"], + "half_width": ["half_width"], + "repolarization_slope": ["repolarization_slope"], + "recovery_slope": ["recovery_slope"], + "num_positive_peaks": ["num_positive_peaks"], + "num_negative_peaks": ["num_negative_peaks"], + "velocity_above": ["velocity_above"], + "velocity_below": ["velocity_below"], + "exp_decay": ["exp_decay"], + "spread": ["spread"], +} + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index cfa9d89fea..6c30e2730b 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -44,7 +44,7 @@ class ComputeTemplateSimilarity(AnalyzerExtension): extension_name = "template_similarity" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index bf0000135c..be0070d94a 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -23,6 +23,11 @@ def get_dataset(): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + # since templates are going to be averaged and this might be a problem for amplitude scaling # we select the 3 units with the largest templates to split analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5056d4ff2a..1bf49f64c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,5 +1,5 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics +from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest import csv @@ -8,6 +8,49 @@ template_metrics = list(_single_channel_metric_name_to_func.keys()) +def test_different_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using different params, and check that they are + actually calculated using the different params. + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread", "half_width"], + metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.7 + assert tm_params["half_width"]["recovery_window_ms"] == 0.7 + + assert tm_params["spread"]["spread_smooth_um"] == 15 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + assert tm_params["half_width"]["spread_smooth_um"] == 20 + + +def test_backwards_compat_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using the metrics_kwargs keyword + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread"], + metrics_kwargs={"recovery_window_ms": 0.8}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.8 + + assert tm_params["spread"]["spread_smooth_um"] == 20 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + + def test_compute_new_template_metrics(small_sorting_analyzer): """ Computes template metrics then computes a subset of template metrics, and checks @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. """ + small_sorting_analyzer.delete_extension("template_metrics") + # calculate just exp_decay small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -47,7 +92,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} ) diff --git a/src/spikeinterface/postprocessing/tests/test_unit_locations.py b/src/spikeinterface/postprocessing/tests/test_unit_locations.py index c40a917a2b..545edb3497 100644 --- a/src/spikeinterface/postprocessing/tests/test_unit_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_unit_locations.py @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite): dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}), dict(method="monopolar_triangulation", radius_um=150), dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"), + dict(method="max_channel"), ], ) def test_extension(self, params): diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 4029fc88c7..df19458316 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -26,7 +26,7 @@ class ComputeUnitLocations(AnalyzerExtension): ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object - method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" + method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} Kwargs which are passed to the method function. These can be found in the docstrings of `compute_center_of_mass`, `compute_grid_convolution` and `compute_monopolar_triangulation`. @@ -39,7 +39,7 @@ class ComputeUnitLocations(AnalyzerExtension): extension_name = "unit_locations" depend_on = ["templates"] - need_recording = True + need_recording = False use_nodepipeline = False need_job_kwargs = False need_backward_compatibility_on_load = True diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 334ebb02d2..d5fc9d2025 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -63,18 +63,15 @@ def __init__( f"Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start/end frames." ) self._decimation_offset = decimation_offset - resample_rate = self._orig_samp_freq / self._decimation_factor + decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate) + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - # in case there was a time_vector, it will be dropped for sanity. - # This is not necessary but consistent with ResampleRecording for parent_segment in recording._recording_segments: - parent_segment.time_vector = None self.add_recording_segment( DecimateRecordingSegment( parent_segment, - resample_rate, + decimated_sampling_frequency, self._orig_samp_freq, decimation_factor, decimation_offset, @@ -93,22 +90,26 @@ class DecimateRecordingSegment(BaseRecordingSegment): def __init__( self, parent_recording_segment, - resample_rate, + decimated_sampling_frequency, parent_rate, decimation_factor, decimation_offset, dtype, ): - if parent_recording_segment.t_start is None: - new_t_start = None + if parent_recording_segment.time_vector is not None: + time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] + decimated_sampling_frequency = None + t_start = None else: - new_t_start = parent_recording_segment.t_start + decimation_offset / parent_rate + time_vector = None + if parent_recording_segment.t_start is None: + t_start = None + else: + t_start = parent_recording_segment.t_start + (decimation_offset / parent_rate) # Do not use BasePreprocessorSegment bcause we have to reset the sampling rate! BaseRecordingSegment.__init__( - self, - sampling_frequency=resample_rate, - t_start=new_t_start, + self, sampling_frequency=decimated_sampling_frequency, t_start=t_start, time_vector=time_vector ) self._parent_segment = parent_recording_segment self._decimation_factor = decimation_factor diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 8f38f01469..00d9a1a407 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -71,8 +71,10 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see if mode in ["noise"]: if noise_levels is None: + random_slices_kwargs = random_chunk_kwargs.copy() + random_slices_kwargs["seed"] = seed noise_levels = get_noise_levels( - recording, return_scaled=False, concatenated=True, seed=seed, **random_chunk_kwargs + recording, return_scaled=False, random_slices_kwargs=random_slices_kwargs ) noise_generator = NoiseGeneratorRecording( num_channels=recording.get_num_channels(), @@ -95,7 +97,8 @@ def __init__(self, recording, list_periods, mode="zeros", noise_levels=None, see rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, noise_generator=noise_generator) + self._kwargs = dict(recording=recording, list_periods=list_periods, mode=mode, seed=seed) + self._kwargs.update(random_chunk_kwargs) class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index 724ba2c963..c18c7d37af 100644 --- a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -14,12 +14,12 @@ def test_clip(): rec1 = clip(rec, a_min=-1.5) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(-2 <= traces0[0] <= 3) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"]) assert traces1.shape[1] == 2 assert np.all(-1.5 <= traces1[1]) @@ -34,11 +34,11 @@ def test_blank_staturation(): rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(traces0 < 3.0) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure a_min = rec1._recording_segments[0].a_min diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 100972f762..aab17560a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -8,19 +8,14 @@ import numpy as np -@pytest.mark.parametrize("N_segments", [1, 2]) -@pytest.mark.parametrize("decimation_offset", [0, 1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("decimation_factor", [1, 9, 10, 11, 100, 101]) -@pytest.mark.parametrize("start_frame", [0, 1, 5, None, 1000]) -@pytest.mark.parametrize("end_frame", [0, 1, 5, None, 1000]) -def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, end_frame): - rec = generate_recording() - - segment_num_samps = [101 + i for i in range(N_segments)] - +@pytest.mark.parametrize("num_segments", [1, 2]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5, 21, 101]) +@pytest.mark.parametrize("decimation_factor", [1, 7, 50]) +def test_decimate(num_segments, decimation_offset, decimation_factor): + segment_num_samps = [20000, 40000] rec = NumpyRecording([np.arange(2 * N).reshape(N, 2) for N in segment_num_samps], 1) - parent_traces = [rec.get_traces(i) for i in range(N_segments)] + parent_traces = [rec.get_traces(i) for i in range(num_segments)] if decimation_offset >= min(segment_num_samps) or decimation_offset >= decimation_factor: with pytest.raises(ValueError): @@ -28,19 +23,59 @@ def test_decimate(N_segments, decimation_offset, decimation_factor, start_frame, return decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) - decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(N_segments)] + decimated_parent_traces = [parent_traces[i][decimation_offset::decimation_factor] for i in range(num_segments)] - if start_frame is None: - start_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) - if end_frame is None: - end_frame = max(decimated_rec.get_num_samples(i) for i in range(N_segments)) + for start_frame in [0, 1, 5, None, 1000]: + for end_frame in [0, 1, 5, None, 1000]: + if start_frame is None: + start_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) + if end_frame is None: + end_frame = max(decimated_rec.get_num_samples(i) for i in range(num_segments)) - for i in range(N_segments): + for i in range(num_segments): + assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] + assert np.all( + decimated_rec.get_traces(i, start_frame, end_frame) + == decimated_parent_traces[i][start_frame:end_frame] + ) + + for i in range(num_segments): assert decimated_rec.get_num_samples(i) == decimated_parent_traces[i].shape[0] assert np.all( decimated_rec.get_traces(i, start_frame, end_frame) == decimated_parent_traces[i][start_frame:end_frame] ) +def test_decimate_with_times(): + rec = generate_recording(durations=[5, 10]) + + # test with times + times = [rec.get_times(0) + 10, rec.get_times(1) + 20] + for i, t in enumerate(times): + rec.set_times(t, i) + + decimation_factor = 2 + decimation_offset = 1 + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + # test with t_start + rec = generate_recording(durations=[5, 10]) + t_starts = [10, 20] + for t_start, rec_segment in zip(t_starts, rec._recording_segments): + rec_segment.t_start = t_start + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + for segment_index in range(rec.get_num_segments()): + assert np.allclose( + decimated_rec.get_times(segment_index), + rec.get_times(segment_index)[decimation_offset::decimation_factor], + ) + + if __name__ == "__main__": test_decimate() diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 1189f04f7d..06bde4e3d1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -163,7 +163,9 @@ def test_output_values(): expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) - si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1) + si_interpolated_recording = spre.interpolate_bad_channels( + recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1 + ) si_interpolated = si_interpolated_recording.get_traces() expected_ts = si_interpolated[:, 1:] @ expected_weights diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 576b570832..151752e0e6 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -15,7 +15,7 @@ def test_normalize_by_quantile(): rec2 = normalize_by_quantile(rec, mode="by_channel") rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 rec2 = normalize_by_quantile(rec, mode="pool_channel") diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index b8bb31015e..a2a06e7a1f 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -15,7 +15,7 @@ def test_rectify(): rec2 = rectify(rec) rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 # import matplotlib.pyplot as plt diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 321d7c9df2..e32d96901e 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -55,11 +55,11 @@ def test_scaling_in_preprocessing_chain(): recording.set_channel_gains(gains) recording.set_channel_offsets(offsets) - centered_recording = CenterRecording(scale_to_uV(recording=recording)) + centered_recording = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_argument = centered_recording.get_traces(return_scaled=True) # Chain preprocessors - centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording)) + centered_recording_scaled = CenterRecording(scale_to_uV(recording=recording), seed=2205) traces_scaled_with_preprocessor = centered_recording_scaled.get_traces() np.testing.assert_allclose(traces_scaled_with_argument, traces_scaled_with_preprocessor) @@ -68,3 +68,8 @@ def test_scaling_in_preprocessing_chain(): traces_scaled_with_preprocessor_and_argument = centered_recording_scaled.get_traces(return_scaled=True) np.testing.assert_allclose(traces_scaled_with_preprocessor, traces_scaled_with_preprocessor_and_argument) + + +if __name__ == "__main__": + test_scale_to_uV() + test_scaling_in_preprocessing_chain() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence.py index 6c2e8ec8b5..20d4f6dfc7 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence.py @@ -9,6 +9,8 @@ import numpy as np +from pathlib import Path + def test_silence(create_cache_folder): @@ -46,4 +48,5 @@ def test_silence(create_cache_folder): if __name__ == "__main__": - test_silence() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_silence(cache_folder) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 04b731de4f..b40627d836 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -5,13 +5,15 @@ from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix +from pathlib import Path + def test_whiten(create_cache_folder): cache_folder = create_cache_folder rec = generate_recording(num_channels=4, seed=2205) print(rec.get_channel_locations()) - random_chunk_kwargs = {} + random_chunk_kwargs = {"seed": 2205} W1, M = compute_whitening_matrix(rec, "global", random_chunk_kwargs, apply_mean=False, radius_um=None) # print(W) # print(M) @@ -47,4 +49,5 @@ def test_whiten(create_cache_folder): if __name__ == "__main__": - test_whiten() + cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" + test_whiten(cache_folder) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index 195969ff79..57400c1199 100644 --- a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -19,6 +19,8 @@ class WhitenRecording(BasePreprocessor): recording : RecordingExtractor The recording extractor to be whitened. dtype : None or dtype, default: None + Datatype of the output recording (covariance matrix estimation + and whitening are performed in float32). If None the the parent dtype is kept. For integer dtype a int_scale must be also given. mode : "global" | "local", default: "global" @@ -74,7 +76,9 @@ def __init__( dtype_ = fix_dtype(recording, dtype) if dtype_.kind == "i": - assert int_scale is not None, "For recording with dtype=int you must set dtype=float32 OR set a int_scale" + assert ( + int_scale is not None + ), "For recording with dtype=int you must set the output dtype to float OR set a int_scale" if W is not None: W = np.asarray(W) @@ -124,7 +128,7 @@ def __init__(self, parent_recording_segment, W, M, dtype, int_scale): def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) traces_dtype = traces.dtype - # if uint --> force int + # if uint --> force float if traces_dtype.kind == "u": traces = traces.astype("float32") @@ -185,6 +189,7 @@ def compute_whitening_matrix( """ random_data = get_random_data_chunks(recording, concatenated=True, return_scaled=False, **random_chunk_kwargs) + random_data = random_data.astype(np.float32) regularize_kwargs = regularize_kwargs if regularize_kwargs is not None else {"method": "GraphicalLassoCV"} diff --git a/src/spikeinterface/qualitymetrics/__init__.py b/src/spikeinterface/qualitymetrics/__init__.py index 9d604f6ae2..754c82d8e3 100644 --- a/src/spikeinterface/qualitymetrics/__init__.py +++ b/src/spikeinterface/qualitymetrics/__init__.py @@ -6,4 +6,4 @@ get_default_qm_params, ) from .pca_metrics import get_quality_pca_metric_list -from .misc_metrics import get_synchrony_counts +from .misc_metrics import _get_synchrony_counts diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 8dfd41cf88..6007de379c 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -520,7 +520,7 @@ def compute_sliding_rp_violations( ) -def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): +def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): """ Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`. @@ -528,10 +528,10 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): ---------- spikes : np.array Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). - synchrony_sizes : numpy array - The synchrony sizes to compute. Should be pre-sorted. all_unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. Expecting all units. + synchrony_sizes : None or np.array, default: None + The synchrony sizes to compute. Should be pre-sorted. Returns ------- @@ -565,37 +565,38 @@ def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of - "synchrony_size" spikes at the exact same sample index. + spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. Parameters ---------- sorting_analyzer : SortingAnalyzer A SortingAnalyzer object. - synchrony_sizes : list or tuple, default: (2, 4, 8) - The synchrony sizes to compute. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + synchrony_sizes: None, default: None + Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. Returns ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - Returns are as many as synchrony_sizes. References ---------- Based on concepts described in [Grün]_ This code was adapted from `Elephant - Electrophysiology Analysis Toolkit `_ """ - assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1" - # Sort the synchrony times so we can slice numpy arrays, instead of using dicts - synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16) - synchrony_sizes_np.sort() - res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes_np]) + if synchrony_sizes is not None: + warning_message = "Custom `synchrony_sizes` is deprecated; the `synchrony_metrics` will be computed using `synchrony_sizes = [2,4,8]`" + warnings.warn(warning_message, DeprecationWarning, stacklevel=2) + + synchrony_sizes = np.array([2, 4, 8]) + + res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting @@ -606,10 +607,10 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids - synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids) + synchrony_counts = _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids) synchrony_metrics_dict = {} - for sync_idx, synchrony_size in enumerate(synchrony_sizes_np): + for sync_idx, synchrony_size in enumerate(synchrony_sizes): sync_id_metrics_dict = {} for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: @@ -623,7 +624,7 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ return res(**synchrony_metrics_dict) -_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8)) +_default_params["synchrony"] = dict() def compute_firing_ranges(sorting_analyzer, bin_size_s=5, percentiles=(5, 95), unit_ids=None): diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 4c68dfea59..c789d1af82 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -6,6 +6,7 @@ from copy import deepcopy import platform from tqdm.auto import tqdm +from warnings import warn import numpy as np @@ -41,6 +42,9 @@ max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" ), silhouette=dict(method=("simplified",)), + isolation_distance=dict(), + l_ratio=dict(), + d_prime=dict(), ) @@ -52,6 +56,7 @@ def get_quality_pca_metric_list(): def compute_pc_metrics( sorting_analyzer, metric_names=None, + metric_params=None, qm_params=None, unit_ids=None, seed=None, @@ -70,7 +75,7 @@ def compute_pc_metrics( metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None List of unit ids to compute metrics for. @@ -86,6 +91,14 @@ def compute_pc_metrics( pc_metrics : dict The computed PC metrics. """ + + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + ) + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + metric_params = qm_params + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" @@ -93,8 +106,8 @@ def compute_pc_metrics( if metric_names is None: metric_names = _possible_pc_metric_names.copy() - if qm_params is None: - qm_params = _default_params + if metric_params is None: + metric_params = _default_params extremum_channels = get_template_extremum_channel(sorting_analyzer) @@ -147,7 +160,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -184,7 +197,7 @@ def compute_pc_metrics( units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) func = _nn_metric_name_to_func[metric_name] - metric_params = qm_params[metric_name] if metric_name in qm_params else {} + metric_params = metric_params[metric_name] if metric_name in metric_params else {} for _, unit_id in units_loop: try: @@ -213,7 +226,7 @@ def compute_pc_metrics( def calculate_pc_metrics( - sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): warnings.warn( "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. Please use compute_pc_metrics instead", @@ -224,7 +237,7 @@ def calculate_pc_metrics( pc_metrics = compute_pc_metrics( sorting_analyzer, metric_names=metric_names, - qm_params=qm_params, + metric_params=metric_params, unit_ids=unit_ids, seed=seed, n_jobs=n_jobs, @@ -977,16 +990,16 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args if max_threads_per_process is None: - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) else: with threadpool_limits(limits=int(max_threads_per_process)): - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) -def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params): +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: @@ -1015,7 +1028,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ if "nearest_neighbor" in metric_names: try: nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] ) except: nn_hit_rate = np.nan @@ -1024,7 +1037,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: - silhouette_method = qm_params["silhouette"]["method"] + silhouette_method = metric_params["silhouette"]["method"] if "simplified" in silhouette_method: try: unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b6a50d60f5..11ce3d0160 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,6 +6,7 @@ from copy import deepcopy import numpy as np +from warnings import warn from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -15,7 +16,8 @@ compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names, - compute_name_to_column_names, + qm_compute_name_to_column_names, + column_name_to_column_dtype, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -31,7 +33,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - qm_params : dict or None + metric_params : dict of dicts or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -54,10 +56,18 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, metric_names=None, + metric_params=None, qm_params=None, peak_sign=None, seed=None, @@ -65,6 +75,12 @@ def _set_params( delete_existing_metrics=False, metrics_to_compute=None, ): + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -80,12 +96,12 @@ def _set_params( if "drift" in metric_names: metric_names.remove("drift") - qm_params_ = get_default_qm_params() - for k in qm_params_: - if qm_params is not None and k in qm_params: - qm_params_[k].update(qm_params[k]) - if "peak_sign" in qm_params_[k] and peak_sign is not None: - qm_params_[k]["peak_sign"] = peak_sign + metric_params_ = get_default_qm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") @@ -101,7 +117,7 @@ def _set_params( metric_names=metric_names, peak_sign=peak_sign, seed=seed, - qm_params=qm_params_, + metric_params=metric_params_, skip_pc_metrics=skip_pc_metrics, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, @@ -125,13 +141,20 @@ def _merge_extension_data( all_unit_ids = new_sorting_analyzer.unit_ids not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] metrics.loc[new_unit_ids, :] = self._compute_metrics( new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs ) + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. + for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + new_data = dict(metrics=metrics) return new_data @@ -141,7 +164,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - qm_params = self.params["qm_params"] + metric_params = self.params["metric_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -177,7 +200,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri func = _misc_metric_name_to_func[metric_name] - params = qm_params[metric_name] if metric_name in qm_params else {} + params = metric_params[metric_name] if metric_name in metric_params else {} res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: @@ -205,7 +228,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, - qm_params=qm_params, + metric_params=metric_params, seed=seed, ) for col, values in pc_metrics.items(): @@ -214,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # add NaN for empty units if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan + # num_spikes is an int and should be 0 + if "num_spikes" in metrics.columns: + metrics.loc[empty_unit_ids, ["num_spikes"]] = 0 # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns # (in case of NaN values) metrics = metrics.convert_dtypes() + + # we do this because the convert_dtypes infers the wrong types sometimes. + # the actual types for columns can be found in column_name_to_column_dtype dictionary. + for column in metrics.columns: + if column in column_name_to_column_dtype: + metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) + return metrics def _run(self, verbose=False, **job_kwargs): @@ -246,7 +279,7 @@ def _run(self, verbose=False, **job_kwargs): # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): # some metrics names produce data columns with other names. This deals with that. - for column_name in compute_name_to_column_names[metric_name]: + for column_name in qm_compute_name_to_column_names[metric_name]: computed_metrics[column_name] = qm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 375dd320ae..23b781eb9d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -55,7 +55,7 @@ } # a dict converting the name of the metric for computation to the output of that computation -compute_name_to_column_names = { +qm_compute_name_to_column_names = { "num_spikes": ["num_spikes"], "firing_rate": ["firing_rate"], "presence_ratio": ["presence_ratio"], @@ -66,7 +66,11 @@ "amplitude_cutoff": ["amplitude_cutoff"], "amplitude_median": ["amplitude_median"], "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], - "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "synchrony": [ + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + ], "firing_range": ["firing_range"], "drift": ["drift_ptp", "drift_std", "drift_mad"], "sd_ratio": ["sd_ratio"], @@ -79,3 +83,38 @@ "silhouette": ["silhouette"], "silhouette_full": ["silhouette_full"], } + +# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them +column_name_to_column_dtype = { + "num_spikes": int, + "firing_rate": float, + "presence_ratio": float, + "snr": float, + "isi_violations_ratio": float, + "isi_violations_count": float, + "rp_violations": float, + "rp_contamination": float, + "sliding_rp_violation": float, + "amplitude_cutoff": float, + "amplitude_median": float, + "amplitude_cv_median": float, + "amplitude_cv_range": float, + "sync_spike_2": float, + "sync_spike_4": float, + "sync_spike_8": float, + "firing_range": float, + "drift_ptp": float, + "drift_std": float, + "drift_mad": float, + "sd_ratio": float, + "isolation_distance": float, + "l_ratio": float, + "d_prime": float, + "nn_hit_rate": float, + "nn_miss_rate": float, + "nn_isolation": float, + "nn_unit_id": float, + "nn_noise_overlap": float, + "silhouette": float, + "silhouette_full": float, +} diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..ac1789a375 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -16,6 +16,11 @@ def small_sorting_analyzer(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -60,6 +65,11 @@ def sorting_analyzer_simple(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..18b49cd862 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -39,7 +39,7 @@ compute_firing_ranges, compute_amplitude_cv_metrics, compute_sd_ratio, - get_synchrony_counts, + _get_synchrony_counts, compute_quality_metrics, ) @@ -69,7 +69,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert calculated_metrics == ["snr"] small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} ) small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) @@ -96,13 +96,13 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): # check that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") new_snr_data = new_quality_metric_extension.get_data()["snr"].values assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" # check that all quality metrics are deleted when parents are recomputed, even after # recomputation @@ -280,10 +280,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params ) for metric, metric_2_data in quality_metrics_2.items(): @@ -352,7 +352,7 @@ def test_synchrony_counts_no_sync(): one_spike["sample_index"] = spike_times one_spike["unit_index"] = spike_units - sync_count = get_synchrony_counts(one_spike, np.array((2)), [0]) + sync_count = _get_synchrony_counts(one_spike, np.array([2, 4, 8]), [0]) assert np.all(sync_count[0] == np.array([0])) @@ -372,7 +372,7 @@ def test_synchrony_counts_one_sync(): two_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) two_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(two_spikes, np.array((2)), [0, 1]) + sync_count = _get_synchrony_counts(two_spikes, np.array([2, 4, 8]), [0, 1]) assert np.all(sync_count[0] == np.array([1, 1])) @@ -392,7 +392,7 @@ def test_synchrony_counts_one_quad_sync(): four_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) four_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(four_spikes, np.array((2, 4)), [0, 1, 2, 3]) + sync_count = _get_synchrony_counts(four_spikes, np.array([2, 4, 8]), [0, 1, 2, 3]) assert np.all(sync_count[0] == np.array([1, 1, 1, 1])) assert np.all(sync_count[1] == np.array([1, 1, 1, 1])) @@ -409,7 +409,7 @@ def test_synchrony_counts_not_all_units(): three_spikes["sample_index"] = np.concatenate((spike_indices, added_spikes_indices)) three_spikes["unit_index"] = np.concatenate((spike_labels, added_spikes_labels)) - sync_count = get_synchrony_counts(three_spikes, np.array((2)), [0, 1, 2]) + sync_count = _get_synchrony_counts(three_spikes, np.array([2, 4, 8]), [0, 1, 2]) assert np.all(sync_count[0] == np.array([0, 1, 1])) @@ -610,9 +610,9 @@ def test_calculate_rp_violations(sorting_analyzer_violations): def test_synchrony_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple sorting = sorting_analyzer.sorting - synchrony_sizes = (2, 3, 4) - synchrony_metrics = compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=synchrony_sizes) - print(synchrony_metrics) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer) + + synchrony_sizes = np.array([2, 4, 8]) # check returns for size in synchrony_sizes: @@ -625,10 +625,8 @@ def test_synchrony_metrics(sorting_analyzer_simple): sorting_sync = add_synchrony_to_sorting(sorting, sync_event_ratio=sync_level) sorting_analyzer_sync = create_sorting_analyzer(sorting_sync, sorting_analyzer.recording, format="memory") - previous_synchrony_metrics = compute_synchrony_metrics( - previous_sorting_analyzer, synchrony_sizes=synchrony_sizes - ) - current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync, synchrony_sizes=synchrony_sizes) + previous_synchrony_metrics = compute_synchrony_metrics(previous_sorting_analyzer) + current_synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_sync) print(current_synchrony_metrics) # check that all values increased for i, col in enumerate(previous_synchrony_metrics._fields): @@ -647,22 +645,17 @@ def test_synchrony_metrics_unit_id_subset(sorting_analyzer_simple): unit_ids_subset = [3, 7] - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics( - sorting_analyzer_simple, synchrony_sizes=synchrony_sizes, unit_ids=unit_ids_subset - ) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple, unit_ids=unit_ids_subset) - assert list(synchrony_metrics.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_2.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_4.keys()) == [3, 7] + assert list(synchrony_metrics.sync_spike_8.keys()) == [3, 7] def test_synchrony_metrics_no_unit_ids(sorting_analyzer_simple): - # all_unit_ids = sorting_analyzer_simple.sorting.unit_ids - - synchrony_sizes = (2,) - (synchrony_metrics,) = compute_synchrony_metrics(sorting_analyzer_simple, synchrony_sizes=synchrony_sizes) - - assert np.all(list(synchrony_metrics.keys()) == sorting_analyzer_simple.unit_ids) + synchrony_metrics = compute_synchrony_metrics(sorting_analyzer_simple) + assert np.all(list(synchrony_metrics.sync_spike_2.keys()) == sorting_analyzer_simple.unit_ids) @pytest.mark.sortingcomponents diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index a6415c58e8..ea8939ebb4 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -24,14 +24,14 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) # print(metrics) qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in metrics.columns assert "isolation_distance" not in metrics.columns @@ -40,7 +40,7 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -48,9 +48,10 @@ def test_compute_quality_metrics(sorting_analyzer_simple): assert "isolation_distance" in metrics.columns -def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): +def test_merging_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple + metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, @@ -59,6 +60,32 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): seed=2205, ) + # sorting_analyzer_simple has ten units + new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]]) + + new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data() + + # we should copy over the metrics after merge + for column in metrics.columns: + assert column in new_metrics.columns + # should copy dtype too + assert metrics[column].dtype == new_metrics[column].dtype + + # 10 units vs 9 units + assert len(metrics.index) > len(new_metrics.index) + + +def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): + + sorting_analyzer = sorting_analyzer_simple + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + # make a copy and make it recordingless sorting_analyzer_norec = sorting_analyzer.save_as(format="memory") sorting_analyzer_norec.delete_extension("quality_metrics") @@ -68,7 +95,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics_norec = compute_quality_metrics( sorting_analyzer_norec, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -101,15 +128,20 @@ def test_empty_units(sorting_analyzer_simple): metrics_empty = compute_quality_metrics( sorting_analyzer_empty, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) - for empty_unit_id in sorting_empty.get_empty_unit_ids(): + # num_spikes are ints not nans so we confirm empty units are nans for everything except + # num_spikes which should be 0 + nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"] + for empty_unit_ids in sorting_empty.get_empty_unit_ids(): from pandas import isnull - assert np.all(isnull(metrics_empty.loc[empty_unit_id].values)) + assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values)) + if "num_spikes" in metrics_empty.columns: + assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 3502d27548..c59fa29c05 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -145,9 +145,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo elif recording.check_serializability("pickle"): recording.dump(output_folder / "spikeinterface_recording.pickle", relative_to=output_folder) else: - # TODO: deprecate and finally remove this after 0.100 - d = {"warning": "The recording is not serializable to json"} - rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") + raise RuntimeError( + "This recording is not serializable and so can not be sorted. Consider `recording.save()` to save a " + "compatible binary file." + ) return output_folder diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 2a9fb34267..ec15506006 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -66,7 +66,7 @@ class Kilosort4Sorter(BaseSorter): "do_correction": True, "keep_good_only": False, "skip_kilosort_preprocessing": False, - "use_binary_file": None, + "use_binary_file": True, "delete_recording_dat": True, } @@ -116,7 +116,7 @@ class Kilosort4Sorter(BaseSorter): "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " - "Default is None. (spikeinterface parameter)", + "Default is True. (spikeinterface parameter)", "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)", } diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index eed693b343..a3a3523591 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,12 +6,16 @@ import numpy as np from spikeinterface.core import NumpySorting -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import ( + cache_preprocessing, + get_prototype_and_waveforms_from_recording, + get_shuffled_recording_slices, +) from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.sortinganalyzer import create_sorting_analyzer @@ -26,7 +30,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { @@ -52,7 +56,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, - "job_kwargs": {"n_jobs": 0.8}, + "job_kwargs": {"n_jobs": 0.5}, + "seed": 42, "debug": False, } @@ -74,18 +79,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ - median reference + zscore", + median reference + whitening", + "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not", + "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)", "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \ memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting", "multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", + "seed": "An int to control how chunks are shuffled while detecting peaks", "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise. In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes - being discovered.""" + being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface""" @classmethod def get_sorter_version(cls): @@ -114,8 +122,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.tools import remove_empty_templates - from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction - from spikeinterface.sortingcomponents.tools import get_prototype_spike + from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) @@ -132,10 +139,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: + if verbose: + print("Preprocessing the recording (bandpass filtering + CMR + whitening)") recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") if num_channels > 1: recording_f = common_reference(recording_f) else: + if verbose: + print("Skipping preprocessing (whitening only)") recording_f = recording recording_f.annotate(is_filtered=True) @@ -158,12 +169,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # TODO add , regularize=True chen ready whitening_kwargs = params["whitening"].copy() whitening_kwargs["dtype"] = "float32" - whitening_kwargs["radius_um"] = radius_um + whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False) if num_channels == 1: whitening_kwargs["regularize"] = False + if whitening_kwargs["regularize"]: + whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, return_scaled=False) + noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -174,9 +187,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() - detection_params.update(job_kwargs) - - detection_params["radius_um"] = detection_params.get("radius_um", 50) + selection_params = params["selection"].copy() + detection_params["radius_um"] = radius_um detection_params["exclude_sweep_ms"] = exclude_sweep_ms detection_params["noise_levels"] = noise_levels @@ -184,17 +196,47 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) + skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform" + max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels + n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) + + if params["debug"]: + clustering_folder = sorter_output_folder / "clustering" + clustering_folder.mkdir(parents=True, exist_ok=True) + np.save(clustering_folder / "noise_levels.npy", noise_levels) + if params["matched_filtering"]: - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000) - prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( + recording_w, + n_peaks=10000, + ms_before=ms_before, + ms_after=ms_after, + seed=params["seed"], + **detection_params, + **job_kwargs, + ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params) + if params["debug"]: + np.save(clustering_folder / "waveforms.npy", waveforms) + np.save(clustering_folder / "prototype.npy", prototype) + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) + waveforms = None + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) - if verbose: - print("We found %d peaks in total" % len(peaks)) + if not skip_peaks and verbose: + print("Found %d peaks in total" % len(peaks)) if params["multi_units_only"]: sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) @@ -202,14 +244,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible selection_params = params["selection"] - selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels) - selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - + selection_params["n_peaks"] = n_peaks selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks(peaks, **selection_params) if verbose: - print("We kept %d peaks for clustering" % len(selected_peaks)) + print("Kept %d peaks for clustering" % len(selected_peaks)) ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets @@ -219,11 +259,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["radius_um"] = radius_um clustering_params["waveforms"]["ms_before"] = ms_before clustering_params["waveforms"]["ms_after"] = ms_after - clustering_params["job_kwargs"] = job_kwargs + clustering_params["few_waveforms"] = waveforms clustering_params["noise_levels"] = noise_levels - clustering_params["ms_before"] = exclude_sweep_ms - clustering_params["ms_after"] = exclude_sweep_ms + clustering_params["ms_before"] = ms_before + clustering_params["ms_after"] = ms_after + clustering_params["verbose"] = verbose clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) legacy = clustering_params.get("legacy", True) @@ -233,7 +275,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_method = "random_projections" labels, peak_labels = find_cluster_from_peaks( - recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params + recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) ## We get the labels for our peaks @@ -248,12 +290,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"]))) sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids) - clustering_folder = sorter_output_folder / "clustering" - clustering_folder.mkdir(parents=True, exist_ok=True) - - if not params["debug"]: - shutil.rmtree(clustering_folder) - else: + if params["debug"]: + np.save(clustering_folder / "peak_labels", peak_labels) np.save(clustering_folder / "labels", labels) np.save(clustering_folder / "peaks", selected_peaks) @@ -284,11 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"].pop("method") matching_params = params["matching"].copy() matching_params["templates"] = templates - matching_job_params = job_kwargs.copy() if matching_method is not None: spikes = find_spikes_from_templates( - recording_w, matching_method, method_kwargs=matching_params, **matching_job_params + recording_w, matching_method, method_kwargs=matching_params, **job_kwargs ) if params["debug"]: @@ -297,7 +334,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(fitting_folder / "spikes", spikes) if verbose: - print("We found %d spikes" % len(spikes)) + print("Found %d spikes" % len(spikes)) ## And this is it! We have a spyking circus sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) @@ -337,10 +374,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting.save(folder=curation_folder) # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + sorting = final_cleaning_circus(recording_w, sorting, templates, merging_params, **job_kwargs) if verbose: - print(f"Final merging, keeping {len(sorting.unit_ids)} units") + print(f"Kept {len(sorting.unit_ids)} units after final merging") folder_to_delete = None cache_mode = params["cache_preprocessing"].get("mode", "memory") @@ -379,17 +416,18 @@ def create_sorting_analyzer_with_templates(sorting, recording, templates, remove return sa -def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): +def final_cleaning_circus(recording, sorting, templates, merging_kwargs, **job_kwargs): from spikeinterface.core.sorting_tools import apply_merges_to_sorting sa = create_sorting_analyzer_with_templates(sorting, recording, templates) - sa.compute("unit_locations", method="monopolar_triangulation") + sa.compute("unit_locations", method="monopolar_triangulation", **job_kwargs) similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {}) - sa.compute("template_similarity", **similarity_kwargs) + sa.compute("template_similarity", **similarity_kwargs, **job_kwargs) correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {}) - sa.compute("correlograms", **correlograms_kwargs) + sa.compute("correlograms", **correlograms_kwargs, **job_kwargs) + auto_merge_kwargs = merging_kwargs.pop("auto_merge", {}) merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs) sorting = apply_merges_to_sorting(sa.sorting, merges) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 99c59f493e..bc173a6ff0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,7 +18,6 @@ from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.core.template import Templates @@ -41,13 +40,7 @@ class CircusClustering: """ _default_params = { - "hdbscan_kwargs": { - "min_cluster_size": 25, - "allow_single_cluster": True, - "core_dist_n_jobs": -1, - "cluster_selection_method": "eom", - # "cluster_selection_epsilon" : 5 ## To be optimized - }, + "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5}, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, @@ -58,21 +51,20 @@ class CircusClustering: }, "radius_um": 100, "n_svd": [5, 2], + "few_waveforms": None, "ms_before": 0.5, "ms_after": 0.5, + "noise_threshold": 4, "rank": 5, "noise_levels": None, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -90,12 +82,25 @@ def main_function(cls, recording, peaks, params): tmp_folder.mkdir(parents=True, exist_ok=True) # SVD for time compression - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)) - few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs - ) + if params["few_waveforms"] is None: + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter) + ) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + wfs = few_wfs[:, :, 0] + else: + offset = int(params["waveforms"]["ms_before"] * fs / 1000) + wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter] + + # Ensure all waveforms have a positive max + wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis] + + # Remove outliers + valid = np.argmax(np.abs(wfs), axis=1) == nbefore + wfs = wfs[valid] - wfs = few_wfs[:, :, 0] from sklearn.decomposition import TruncatedSVD tsvd = TruncatedSVD(params["n_svd"][0]) @@ -193,7 +198,7 @@ def main_function(cls, recording, peaks, params): original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters - min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50) + min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10) peak_labels, _ = split_clusters( original_labels, @@ -229,47 +234,64 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + if d["rank"] is not None: from spikeinterface.sortingcomponents.matching.circus import compress_templates _, _, _, templates_array = compress_templates(templates_array, d["rank"]) templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) peak_labels[mask] = -1 + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py index c7d57b14e4..e8bc5a1d49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clean.py +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -32,7 +32,6 @@ def clean_clusters( count = np.zeros(n, dtype="int64") for i, label in enumerate(labels_set): count[i] = np.sum(peak_labels == label) - print(count) templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) @@ -42,6 +41,5 @@ def clean_clusters( max_values = -np.min(templates, axis=(1, 2)) elif peak_sign == "pos": max_values = np.max(templates, axis=(1, 2)) - print(max_values) return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 08a1384333..93db9a268f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -570,7 +570,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, ) else: recording = NumpyRecording(zdata, sampling_frequency=fs) - recording = SharedMemoryRecording.from_recording(recording) + recording = SharedMemoryRecording.from_recording(recording, **job_kwargs) recording = recording.set_probe(templates.probe) recording.annotate(is_filtered=True) @@ -587,6 +587,8 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, keep_searching = True + local_job_kargs = {"n_jobs": 1, "progress_bar": False} + DEBUG = False while keep_searching: @@ -604,7 +606,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, local_params.update({"ignore_inds": ignore_inds + [i]}) spikes, more_outputs = find_spikes_from_templates( - sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs + sub_recording, + method="circus-omp-svd", + method_kwargs=local_params, + extra_outputs=True, + **local_job_kargs, ) local_params["precomputed"] = more_outputs valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) diff --git a/src/spikeinterface/sortingcomponents/clustering/dummy.py b/src/spikeinterface/sortingcomponents/clustering/dummy.py index c1032ee6c6..b5761ad5cf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/dummy.py +++ b/src/spikeinterface/sortingcomponents/clustering/dummy.py @@ -13,7 +13,7 @@ class DummyClustering: _default_params = {} @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.arange(recording.get_num_channels(), dtype="int64") peak_labels = peaks["channel_index"] return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 99881f2f34..ba0fe6f9ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -41,7 +41,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, params = method_class._default_params.copy() params.update(**method_kwargs) - outputs = method_class.main_function(recording, peaks, params) + outputs = method_class.main_function(recording, peaks, params, job_kwargs=job_kwargs) if extra_outputs: return outputs diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index ae772206bb..dc76d787f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -25,18 +25,17 @@ class PositionClustering: "hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "debug": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) + peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **job_kwargs) else: peak_locations = d["peak_locations"] diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index d23eb26239..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -42,23 +42,14 @@ class PositionAndFeaturesClustering: "ms_before": 1.5, "ms_after": 1.5, "cleaning_method": "dip", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.preprocessing import QuantileTransformer assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - if "n_jobs" in params["job_kwargs"]: - if params["job_kwargs"]["n_jobs"] == -1: - params["job_kwargs"]["n_jobs"] = os.cpu_count() - - if "core_dist_n_jobs" in params["hdbscan_kwargs"]: - if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: - params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - d = params peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] @@ -80,7 +71,7 @@ def main_function(cls, recording, peaks, params): } features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **job_kwargs ) hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32) @@ -150,10 +141,10 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=None, copy=True, - **params["job_kwargs"], + **job_kwargs, ) - noise_levels = get_noise_levels(recording, return_scaled=False) + noise_levels = get_noise_levels(recording, return_scaled=False, **job_kwargs) labels, peak_labels = remove_duplicates( wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] ) @@ -181,7 +172,7 @@ def main_function(cls, recording, peaks, params): nbefore, nafter, return_scaled=False, - **params["job_kwargs"], + **job_kwargs, ) templates = Templates( templates_array=templates_array, @@ -193,7 +184,7 @@ def main_function(cls, recording, peaks, params): ) labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"] + templates, peak_labels, job_kwargs=job_kwargs, **params["cleaning_kwargs"] ) shutil.rmtree(tmp_folder) diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py index 4dfe3c960c..c4f372fc21 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py @@ -38,7 +38,6 @@ class PositionAndPCAClustering: "ms_after": 2.5, "n_components_by_channel": 3, "n_components": 5, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "waveform_mode": "shared_memory", @@ -73,7 +72,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): # res = PositionClustering(recording, peaks, params) assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed" @@ -85,9 +84,7 @@ def main_function(cls, recording, peaks, params): if params["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks( - recording, peaks, **params["peak_localization_kwargs"], **params["job_kwargs"] - ) + peak_locations = localize_peaks(recording, peaks, **params["peak_localization_kwargs"], **job_kwargs) else: peak_locations = params["peak_locations"] @@ -155,7 +152,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask, copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) noise = get_random_data_chunks( @@ -222,7 +219,7 @@ def main_function(cls, recording, peaks, params): dtype=recording.get_dtype(), sparsity_mask=sparsity_mask3, copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], + **job_kwargs, ) clean_peak_labels, peak_sample_shifts = auto_clean_clustering( diff --git a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py index 788addf1e6..0f7390d7ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_ptp_scaled.py @@ -26,7 +26,6 @@ class PositionPTPScaledClustering: "ptps": None, "scales": (1, 1, 10), "peak_localization_kwargs": {"method": "center_of_mass"}, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_kwargs": { "min_cluster_size": 20, "min_samples": 20, @@ -38,7 +37,7 @@ class PositionPTPScaledClustering: } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params @@ -60,7 +59,7 @@ def main_function(cls, recording, peaks, params): if d["ptps"] is None: (ptps,) = compute_features_from_peaks( - recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **d["job_kwargs"] + recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **job_kwargs ) else: ptps = d["ptps"] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f7ca999d53..1d4d8881ad 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -16,8 +16,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates @@ -54,17 +53,15 @@ class RandomProjectionClustering: "random_seed": 42, "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, + "noise_threshold": 4, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -133,44 +130,59 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) - sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) + + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) + empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) - if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) + peak_labels[mask] = -1 - cleaning_matching_params = job_kwargs.copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params[value] = None - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + + if verbose: + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 8b9acbc92d..5f8ac99848 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -23,7 +23,7 @@ get_random_data_chunks, extract_waveforms_to_buffers, ) -from .clustering_tools import auto_clean_clustering, auto_split_clustering +from .clustering_tools import auto_clean_clustering class SlidingHdbscanClustering: @@ -55,18 +55,17 @@ class SlidingHdbscanClustering: "auto_merge_quantile_limit": 0.8, "ratio_num_channel_intersect": 0.5, # ~ 'auto_trash_misalignment_shift' : 4, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed" params = cls._check_params(recording, peaks, params) wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params) wfs_arrays2, sparsity_mask2 = cls._prepare_clean( - recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params + recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params, job_kwargs ) clean_peak_labels, peak_sample_shifts = cls._clean_cluster( @@ -100,7 +99,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): d = params tmp_folder = params["tmp_folder"] @@ -145,7 +144,7 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) # noise @@ -401,7 +400,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): return peak_labels @classmethod - def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d): + def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels, d, job_kwargs=dict()): tmp_folder = d["tmp_folder"] if tmp_folder is None: wf_folder = None @@ -465,7 +464,7 @@ def _prepare_clean(cls, recording, peaks, wfs_arrays, sparsity_mask, peak_labels dtype=recording.get_dtype(), sparsity_mask=sparsity_mask2, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays2, sparsity_mask2 diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index a6ffa5fdc2..40cedacdc5 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -71,11 +71,10 @@ class SlidingNNClustering: "tmp_folder": None, "verbose": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1}, } @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_NUMBA, "SlidingNN needs numba to work" assert HAVE_TORCH, "SlidingNN needs torch to work" assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work" @@ -126,16 +125,16 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, sparsity_mask=sparsity_mask, copy=(d["waveform_mode"] == "shared_memory"), - **d["job_kwargs"], + **job_kwargs, ) return wfs_arrays, sparsity_mask @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): d = params - # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) + # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params, job_kwargs) # prepare neighborhood parameters fs = recording.get_sampling_frequency() @@ -228,7 +227,7 @@ def main_function(cls, recording, peaks, params): n_channel_neighbors=d["n_channel_neighbors"], low_memory=d["low_memory"], knn_verbose=d["verbose"], - n_jobs=d["job_kwargs"]["n_jobs"], + n_jobs=job_kwargs["n_jobs"], ) # remove the first nearest neighbor (which should be self) knn_distances = knn_distances[:, 1:] @@ -297,7 +296,7 @@ def main_function(cls, recording, peaks, params): # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed clusterer = hdbscan.HDBSCAN( prediction_data=True, - core_dist_n_jobs=d["job_kwargs"]["n_jobs"], + core_dist_n_jobs=job_kwargs["n_jobs"], **d["hdbscan_kwargs"], ).fit(embeddings_chunk) diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py index 13af5b0fab..59472d1374 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tdc.py +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -9,27 +9,21 @@ from spikeinterface.core import ( get_channel_distances, - Templates, - compute_sparsity, get_global_tmp_folder, ) from spikeinterface.core.node_pipeline import ( run_node_pipeline, - ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever, ) -from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing -from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel from spikeinterface.sortingcomponents.peak_selection import select_peaks -from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.sortingcomponents.clustering.split import split_clusters from spikeinterface.sortingcomponents.clustering.merge import merge_clusters -from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse class TdcClustering: @@ -50,15 +44,12 @@ class TdcClustering: "merge_radius_um": 40.0, "threshold_diff": 1.5, }, - "job_kwargs": {}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): import hdbscan - job_kwargs = params["job_kwargs"] - if params["folder"] is None: randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) clustering_folder = get_global_tmp_folder() / f"tdcclustering_{randname}" diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index e2a0d273d6..64cc0f39c4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -172,8 +172,6 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): """ - print("apply_waveforms_shift") - if inplace: aligned_waveforms = waveforms else: @@ -193,6 +191,4 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False): else: aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :] - print("apply_waveforms_shift DONE") - return aligned_waveforms diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index e2b6b1a2bc..bfedd4e1ee 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -22,20 +22,19 @@ """ +import gc import warnings -from tqdm.auto import trange import numpy as np - -import gc +from tqdm.auto import trange from .motion_utils import ( Motion, + get_spatial_bin_edges, get_spatial_windows, get_window_domains, - scipy_conv1d, make_2d_motion_histogram, - get_spatial_bin_edges, + scipy_conv1d, ) @@ -979,7 +978,7 @@ def xcorr_windows( if max_disp_um is None: if rigid: - max_disp_um = int(spatial_bin_edges_um.ptp() // 4) + max_disp_um = int(np.ptp(spatial_bin_edges_um) // 4) else: max_disp_um = int(win_scale_um // 4) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a5e6ded519..fc8ccb788b 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -6,6 +6,8 @@ from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing.filter import fix_dtype +from .motion_utils import ensure_time_bin_edges, ensure_time_bins + def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray: """ @@ -54,6 +56,7 @@ def interpolate_motion_on_traces( segment_index=None, channel_inds=None, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, dtype=None, @@ -61,7 +64,11 @@ def interpolate_motion_on_traces( """ Apply inverse motion with spatial interpolation on traces. - Traces can be full traces, but also waveforms snippets. + Traces can be full traces, but also waveforms snippets. Times used for looking up + displacements are controlled by interpolation_time_bin_edges_s or + interpolation_time_bin_centers_s, or fall back to the Motion object's time bins + by default; times in the recording outside these time bins use the closest edge + bin's displacement value during interpolation. Parameters ---------- @@ -80,6 +87,9 @@ def interpolate_motion_on_traces( interpolation_time_bin_centers_s : None or np.array Manually specify the time bins which the interpolation happens in for this segment. If None, these are the motion estimate's time bins. + interpolation_time_bin_edges_s : None or np.array + If present, interpolation chunks will be the time bins defined by these edges + rather than interpolation_time_bin_centers_s or the motion's bins. spatial_interpolation_method : "idw" | "kriging", default: "kriging" The spatial interpolation method used to interpolate the channel locations: * idw : Inverse Distance Weighing @@ -119,26 +129,33 @@ def interpolate_motion_on_traces( total_num_chans = channel_locations.shape[0] # -- determine the blocks of frames that will land in the same interpolation time bin - time_bins = interpolation_time_bin_centers_s - if time_bins is None: - time_bins = motion.temporal_bins_s[segment_index] - bin_s = time_bins[1] - time_bins[0] - bins_start = time_bins[0] - 0.5 * bin_s - # nearest bin center for each frame? - bin_inds = (times - bins_start) // bin_s - bin_inds = bin_inds.astype(int) + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index] + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index] + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) + + # bin the frame times according to the interpolation time bins. + # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i] + # hence the -1. doing it with "left" is not as nice -- we want t==b[0] + # to lead to i=1 (rounding down). + interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1 + # the time bins may not cover the whole set of times in the recording, # so we need to clip these indices to the valid range - np.clip(bin_inds, 0, time_bins.size, out=bin_inds) + n_bins = interpolation_time_bin_edges_s.shape[0] - 1 + np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds) # -- what are the possibilities here anyway? - bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1) + interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1) # inperpolation kernel will be the same per temporal bin interp_times = np.empty(total_num_chans) current_start_index = 0 - for bin_ind in bins_here: - bin_time = time_bins[bin_ind] + for interp_bin_ind in interpolation_bins_here: + bin_time = interpolation_time_bin_centers_s[interp_bin_ind] interp_times.fill(bin_time) channel_motions = motion.get_displacement_at_time_and_depth( interp_times, @@ -166,16 +183,17 @@ def interpolate_motion_on_traces( # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}") # plt.show() + # quick search logic to find frames corresponding to this interpolation bin in the recording # quickly find the end of this bin, which is also the start of the next next_start_index = current_start_index + np.searchsorted( - bin_inds[current_start_index:], bin_ind + 1, side="left" + interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left" ) - in_bin = slice(current_start_index, next_start_index) + frames_in_bin = slice(current_start_index, next_start_index) # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing # in ChunkRecordingExecutor) - np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin]) + np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index return traces_corrected @@ -297,6 +315,7 @@ def __init__( p=1, num_closest=3, interpolation_time_bin_centers_s=None, + interpolation_time_bin_edges_s=None, interpolation_time_bin_size_s=None, dtype=None, **spatial_interpolation_kwargs, @@ -363,9 +382,14 @@ def __init__( # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below - if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None: if interpolation_time_bin_size_s is None: interpolation_time_bin_centers_s = motion.temporal_bins_s + interpolation_time_bin_edges_s = motion.temporal_bin_edges_s + else: + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins( + interpolation_time_bin_centers_s, interpolation_time_bin_edges_s + ) for segment_index, parent_segment in enumerate(recording._recording_segments): # finish the per-segment part of the time bin logic @@ -375,8 +399,13 @@ def __init__( t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) halfbin = interpolation_time_bin_size_s / 2.0 segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + segment_interpolation_time_bin_edges_s = np.arange( + t_start, t_end + halfbin, interpolation_time_bin_size_s + ) + assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,) else: segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index] rec_segment = InterpolateMotionRecordingSegment( parent_segment, @@ -387,6 +416,7 @@ def __init__( channel_inds, segment_index, segment_interpolation_time_bins_s, + segment_interpolation_time_bin_edges_s, dtype=dtype_, ) self.add_recording_segment(rec_segment) @@ -420,6 +450,7 @@ def __init__( channel_inds, segment_index, interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s, dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -429,13 +460,11 @@ def __init__( self.channel_inds = channel_inds self.segment_index = segment_index self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s self.dtype = dtype self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): - if self.time_vector is not None: - raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") - if start_frame is None: start_frame = 0 if end_frame is None: @@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): channel_inds=self.channel_inds, spatial_interpolation_method=self.spatial_interpolation_method, spatial_interpolation_kwargs=self.spatial_interpolation_kwargs, - interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s, + interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s, ) if channel_indices is not None: diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 635624cca8..680d75f221 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -1,5 +1,5 @@ -import warnings import json +import warnings from pathlib import Path import numpy as np @@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y" self.direction = direction self.dim = ["x", "y", "z"].index(direction) self.check_properties() + self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s] def check_properties(self): assert all(d.ndim == 2 for d in self.displacement) @@ -576,3 +577,40 @@ def make_3d_motion_histograms( motion_histograms = np.log2(1 + motion_histograms) return motion_histograms, temporal_bin_edges, spatial_bin_edges + + +def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None): + """Ensure that both bin edges and bin centers are present + + If either of the inputs are None but not both, the missing is reconstructed + from the present. Going from edges to centers is done by taking midpoints. + Going from centers to edges is done by taking midpoints and padding with the + left and rightmost centers. + + Parameters + ---------- + time_bin_centers_s : None or np.array + time_bin_edges_s : None or np.array + + Returns + ------- + time_bin_centers_s, time_bin_edges_s + """ + if time_bin_centers_s is None and time_bin_edges_s is None: + raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.") + + if time_bin_centers_s is None: + assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2 + time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + + if time_bin_edges_s is None: + time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype) + time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]] + if time_bin_centers_s.size > 2: + time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1]) + + return time_bin_centers_s, time_bin_edges_s + + +def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None): + return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1] diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py index e022f0cc6c..e4ba870325 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py @@ -1,16 +1,14 @@ -from pathlib import Path +import warnings import numpy as np -import pytest import spikeinterface.core as sc -from spikeinterface import download_dataset +from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.motion.motion_interpolation import ( InterpolateMotionRecording, correct_motion_on_peaks, interpolate_motion, interpolate_motion_on_traces, ) -from spikeinterface.sortingcomponents.motion import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -67,43 +65,45 @@ def test_interpolate_motion_on_traces(): times = rec.get_times()[0:30000] for method in ("kriging", "idw", "nearest"): - traces_corrected = interpolate_motion_on_traces( - traces, - times, - channel_locations, - motion, - channel_inds=None, - spatial_interpolation_method=method, - # spatial_interpolation_kwargs={}, - spatial_interpolation_kwargs={"force_extrapolate": True}, - ) - assert traces.shape == traces_corrected.shape - assert traces.dtype == traces_corrected.dtype + for interpolation_time_bin_centers_s in (None, np.linspace(*times[[0, -1]], num=3)): + traces_corrected = interpolate_motion_on_traces( + traces, + times, + channel_locations, + motion, + channel_inds=None, + spatial_interpolation_method=method, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, + ) + assert traces.shape == traces_corrected.shape + assert traces.dtype == traces_corrected.dtype def test_interpolation_simple(): # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. # there will be 9 chans of drift, so we add 9 chans of padding to the bottom - nt = nc0 = 10 # these need to be the same for this test - nc1 = nc0 + nc0 - 1 - traces = np.zeros((nt, nc1), dtype="float32") - traces[:, :nc0] = np.eye(nc0) + n_samples = num_chans_orig = 10 # these need to be the same for this test + num_chans_drifted = num_chans_orig + num_chans_orig - 1 + traces = np.zeros((n_samples, num_chans_drifted), dtype="float32") + traces[:, :num_chans_orig] = np.eye(num_chans_orig) rec = sc.NumpyRecording(traces, sampling_frequency=1) - rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(num_chans_drifted), np.arange(num_chans_drifted)]) - true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + true_motion = Motion(np.arange(n_samples)[:, None], 0.5 + np.arange(n_samples), np.zeros(1)) rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) - assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) - assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) + assert np.array_equal(traces_corrected[:, 0], np.ones(n_samples)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((n_samples, num_chans_orig - 1))) # let's try a new version where we interpolate too slowly rec_corrected = interpolate_motion( rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 ) traces_corrected = rec_corrected.get_traces() - assert traces_corrected.shape == (nc0, nc0) + assert traces_corrected.shape == (num_chans_orig, num_chans_orig) # what happens with nearest here? # well... due to rounding towards the nearest even number, the motion (which at # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest @@ -115,6 +115,66 @@ def test_interpolation_simple(): assert np.all(traces_corrected[:, 2:] == 0) +def test_cross_band_interpolation(): + """Simple version of using LFP to interpolate AP data + + This also tests the time vector implementation in interpolation. + The idea is to have two recordings which are all 0s with a 1 that + moves from one channel to another after 3s. They're at different + sampling frequencies. motion estimation in one sampling frequency + applied to the other should still lead to perfect correction. + """ + from spikeinterface.sortingcomponents.motion import estimate_motion + + # sampling freqs and timing for AP and LFP recordings + fs_lfp = 50.0 + fs_ap = 300.0 + t_start = 10.0 + total_duration = 5.0 + num_samples_lfp = int(fs_lfp * total_duration) + num_samples_ap = int(fs_ap * total_duration) + t_switch = 3 + + # because interpolation uses bin centers logic, there will be a half + # bin offset at the change point in the AP recording. + halfbin_ap_lfp = int(0.5 * (fs_ap / fs_lfp)) + + # channel geometry + num_chans = 10 + geom = np.c_[np.zeros(num_chans), np.arange(num_chans)] + + # make an LFP recording which drifts a bit + traces_lfp = np.zeros((num_samples_lfp, num_chans)) + traces_lfp[: int(t_switch * fs_lfp), 5] = 1.0 + traces_lfp[int(t_switch * fs_lfp) :, 6] = 1.0 + rec_lfp = sc.NumpyRecording(traces_lfp, sampling_frequency=fs_lfp) + rec_lfp.set_dummy_probe_from_locations(geom) + + # same for AP + traces_ap = np.zeros((num_samples_ap, num_chans)) + traces_ap[: int(t_switch * fs_ap) - halfbin_ap_lfp, 5] = 1.0 + traces_ap[int(t_switch * fs_ap) - halfbin_ap_lfp :, 6] = 1.0 + rec_ap = sc.NumpyRecording(traces_ap, sampling_frequency=fs_ap) + rec_ap.set_dummy_probe_from_locations(geom) + + # set times for both, and silence the warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + rec_lfp.set_times(t_start + np.arange(num_samples_lfp) / fs_lfp) + rec_ap.set_times(t_start + np.arange(num_samples_ap) / fs_ap) + + # estimate motion + motion = estimate_motion(rec_lfp, method="dredge_lfp", rigid=True) + + # nearest to keep it simple + rec_corrected = interpolate_motion(rec_ap, motion, spatial_interpolation_method="nearest", num_closest=2) + traces_corrected = rec_corrected.get_traces() + target = np.zeros((num_samples_ap, num_chans - 2)) + target[:, 4] = 1 + ii, jj = np.nonzero(traces_corrected) + assert np.array_equal(traces_corrected, target) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() motion = make_fake_motion(rec) @@ -147,6 +207,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": # test_correct_motion_on_peaks() - # test_interpolate_motion_on_traces() - test_interpolation_simple() - test_InterpolateMotionRecording() + test_interpolate_motion_on_traces() + # test_interpolation_simple() + # test_InterpolateMotionRecording() + test_cross_band_interpolation() diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index 5b1d33b334..12955e2c40 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -57,6 +57,7 @@ def detect_peaks( folder=None, names=None, skip_after_n_peaks=None, + recording_slices=None, **kwargs, ): """Peak detection based on threshold crossing in term of k x MAD. @@ -83,6 +84,10 @@ def detect_peaks( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. + recording_slices : None | list[tuple] + Optionaly give a list of slices to run the pipeline only on some chunks of the recording. + It must be a list of (segment_index, frame_start, frame_stop). + If None (default), the function iterates over the entire duration of the recording. {method_doc} {job_doc} @@ -113,7 +118,11 @@ def detect_peaks( squeeze_output = True else: squeeze_output = False - job_name += f" + {len(pipeline_nodes)} nodes" + if len(pipeline_nodes) == 1: + plural = "" + else: + plural = "s" + job_name += f" + {len(pipeline_nodes)} node{plural}" # because node are modified inplace (insert parent) they need to copy incase # the same pipeline is run several times @@ -135,6 +144,7 @@ def detect_peaks( folder=folder, names=names, skip_after_n_peaks=skip_after_n_peaks, + recording_slices=recording_slices, ) return outs @@ -671,7 +681,6 @@ def __init__( medians = medians[:, None] noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817 self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("z", "float32")]) def get_dtype(self): @@ -721,8 +730,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (np.zeros(0, dtype=self._dtype),) peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore - peak_amplitude = traces[peak_sample_ind, peak_chan_ind] + local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype) local_peaks["sample_index"] = peak_sample_ind local_peaks["channel_index"] = peak_chan_ind diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 08bcabf5e5..1e4e0edded 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -33,7 +33,7 @@ get_grid_convolution_templates_and_weights, ) -from .tools import get_prototype_spike +from .tools import get_prototype_and_waveforms_from_peaks def get_localization_pipeline_nodes( @@ -73,8 +73,8 @@ def get_localization_pipeline_nodes( assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) # extract prototypes silently job_kwargs["progress_bar"] = False - method_kwargs["prototype"] = get_prototype_spike( - recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + method_kwargs["prototype"], _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py index 01e4445a13..d5e5b6be1b 100644 --- a/src/spikeinterface/sortingcomponents/tests/common.py +++ b/src/spikeinterface/sortingcomponents/tests/common.py @@ -21,4 +21,10 @@ def make_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), seed=2205, ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + return recording, sorting diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 7c34f5948d..341ed3426d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -22,7 +22,7 @@ ) from spikeinterface.core.node_pipeline import run_node_pipeline -from spikeinterface.sortingcomponents.tools import get_prototype_spike +from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_peaks from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_spike(recording, peaks_by_channel_np, ms_before, ms_after, **job_kwargs) + prototype, _, _ = get_prototype_and_waveforms_from_peaks( + recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) peaks_local_mf_filtering = detect_peaks( recording, diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1501582336..1bd2381cda 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -69,25 +69,174 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs): +def get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ from spikeinterface.sortingcomponents.peak_selection import select_peaks + _, job_kwargs = split_job_kwargs(all_kwargs) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter)) - + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) waveforms = extract_waveform_at_max_channel( recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) + + with np.errstate(divide="ignore", invalid="ignore"): + prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms_from_recording( + recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform from peaks detected on the fly. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + node = ExtractSparseWaveforms( + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) + + pipeline_nodes = [node] + + recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) + + res = detect_peaks( + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, + ) + + rng = np.random.RandomState(seed) + indices = rng.permutation(np.arange(len(res[0]))) + + few_peaks = res[0][indices[:n_peaks]] + waveforms = res[1][indices[:n_peaks]] + with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform either from peaks or from a peak detection. Note that in case + of a peak detection, the detection stops as soon as n_peaks are detected. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. + """ + if peaks is None: + return get_prototype_and_waveforms_from_recording( + recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) + else: + return get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels() - if num_channels < 32: + if num_channels <= 32: return False else: locations = recording.get_channel_locations() @@ -151,3 +300,20 @@ def fit_sigmoid(xdata, ydata, p0=None): popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) return popt + + +def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): + from spikeinterface.core.job_tools import ensure_chunk_size + from spikeinterface.core.job_tools import divide_segment_into_chunks + + chunk_size = ensure_chunk_size(recording, **job_kwargs) + recording_slices = [] + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + chunks = divide_segment_into_chunks(num_frames, chunk_size) + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + + rng = np.random.default_rng(seed) + recording_slices = rng.permutation(recording_slices) + + return recording_slices diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index c8acd93dc2..c211a277f8 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -9,7 +9,13 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): # the doc is copied form CrossCorrelogramsWidget def __init__(self, *args, **kargs): - CrossCorrelogramsWidget.__init__(self, *args, **kargs) + _ = kargs.pop("min_similarity_for_correlograms", 0.0) + CrossCorrelogramsWidget.__init__( + self, + *args, + **kargs, + min_similarity_for_correlograms=None, + ) def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index cdb2041aa3..88dd803323 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -21,7 +21,8 @@ class CrossCorrelogramsWidget(BaseWidget): List of unit ids min_similarity_for_correlograms : float, default: 0.2 For sortingview backend. Threshold for computing pair-wise cross-correlograms. - If template similarity between two units is below this threshold, the cross-correlogram is not displayed + If template similarity between two units is below this threshold, the cross-correlogram is not displayed. + For auto-correlograms plot, this is automatically set to None. window_ms : float, default: 100.0 Window for CCGs in ms. If correlograms are already computed (e.g. with SortingAnalyzer), this argument is ignored diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index a113298851..8eada29b0e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -2,6 +2,8 @@ import numpy as np +import warnings + from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget @@ -14,6 +16,9 @@ from ..core import SortingAnalyzer +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] + + class SortingSummaryWidget(BaseWidget): """ Plots spike sorting summary. @@ -42,14 +47,24 @@ class SortingSummaryWidget(BaseWidget): label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) - unit_table_properties : list or None, default: None + displayed_unit_properties : list or None, default: None List of properties to be added to the unit table. These may be drawn from the sorting extractor, and, if available, - the quality_metrics and template_metrics extensions of the SortingAnalyzer. + the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer. See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - (sortingview backend) + extra_unit_properties : dict or None, default: None + A dict with extra units properties to display. + curation_dict : dict or None, default: None + When curation is True, optionaly the viewer can get a previous 'curation_dict' + to continue/check previous curations on this analyzer. + In this case label_definitions must be None beacuse it is already included in the curation_dict. + (spikeinterface_gui backend) + label_definitions : dict or None, default: None + When curation is True, optionaly the user can provide a label_definitions dict. + This replaces the label_choices in the curation_format. + (spikeinterface_gui backend) """ def __init__( @@ -60,11 +75,24 @@ def __init__( max_amplitudes_per_unit=None, min_similarity_for_correlograms=0.2, curation=False, - unit_table_properties=None, + displayed_unit_properties=None, + extra_unit_properties=None, label_choices=None, + curation_dict=None, + label_definitions=None, backend=None, + unit_table_properties=None, **backend_kwargs, ): + + if unit_table_properties is not None: + warnings.warn( + "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", + category=DeprecationWarning, + stacklevel=2, + ) + displayed_unit_properties = unit_table_properties + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions( sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] @@ -74,18 +102,29 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - plot_data = dict( + if curation_dict is not None and label_definitions is not None: + raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") + + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, - unit_table_properties=unit_table_properties, + displayed_unit_properties=displayed_unit_properties, + extra_unit_properties=extra_unit_properties, curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, + curation_dict=curation_dict, + label_definitions=label_definitions, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv @@ -156,7 +195,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores ) if dp.curation: @@ -190,9 +229,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui + from spikeinterface_gui import run_mainwindow - app = spikeinterface_gui.mkQApp() - win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"]) - win.show() - app.exec_() + run_mainwindow( + sorting_analyzer, + with_traces=True, + curation=data_plot["curation"], + curation_dict=data_plot["curation_dict"], + label_definitions=data_plot["label_definitions"], + extra_unit_properties=data_plot["extra_unit_properties"], + displayed_unit_properties=data_plot["displayed_unit_properties"], + ) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 80f58f5ad9..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,9 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), + quality_metrics=dict( + metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"] + ), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -531,18 +533,29 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary( + self.sorting_analyzer_dense, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) sw.plot_sorting_summary( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, + displayed_unit_properties=[], backend=backend, **self.backend_kwargs[backend], ) - # add unit_properties + # select unit_properties sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["firing_rate", "snr"], + displayed_unit_properties=["firing_rate", "snr"], backend=backend, **self.backend_kwargs[backend], ) @@ -550,7 +563,7 @@ def test_plot_sorting_summary(self): with self.assertWarns(UserWarning): sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["missing_property"], + displayed_unit_properties=["missing_property"], backend=backend, **self.backend_kwargs[backend], ) @@ -688,9 +701,9 @@ def test_plot_motion_info(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - # mytest.test_plot_sorting_summary() + mytest.test_plot_sorting_summary() # mytest.test_plot_motion() - mytest.test_plot_motion_info() - plt.show() + # mytest.test_plot_motion_info() + # plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 86f2350a85..f5dadc780f 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -52,6 +52,8 @@ class TracesWidget(BaseWidget): If dict, keys should be the same as recording keys scale : float, default: 1 Scale factor for the traces + vspacing_factor : float, default: 1.5 + Vertical spacing between channels as a multiple of maximum channel amplitude with_colorbar : bool, default: True When mode is "map", a colorbar is added tile_size : int, default: 1500 @@ -82,6 +84,7 @@ def __init__( tile_size=1500, seconds_per_row=0.2, scale=1, + vspacing_factor=1.5, with_colorbar=True, add_legend=True, backend=None, @@ -168,7 +171,7 @@ def __init__( traces0 = list_traces[0] mean_channel_std = np.mean(np.std(traces0, axis=0)) max_channel_amp = np.max(np.max(np.abs(traces0), axis=0)) - vspacing = max_channel_amp * 1.5 + vspacing = max_channel_amp * vspacing_factor if rec0.get_channel_groups() is None: color_groups = False diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 755e60ccbf..9466110110 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -107,82 +107,90 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # and use custum grid spec fig = self.figure nrows = 2 - ncols = 3 - if sorting_analyzer.has_extension("correlograms") or sorting_analyzer.has_extension("spike_amplitudes"): + ncols = 2 + if sorting_analyzer.has_extension("correlograms"): + ncols += 1 + if sorting_analyzer.has_extension("waveforms"): ncols += 1 if sorting_analyzer.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) + col_counter = 0 - if sorting_analyzer.has_extension("unit_locations"): - ax1 = fig.add_subplot(gs[:2, 0]) - # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - w = UnitLocationsWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_legend=False, - backend="matplotlib", - ax=ax1, - **unitlocationswidget_kwargs, - ) - - unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - unit_location = unit_locations[unit_id] - x, y = unit_location[0], unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - w = UnitWaveformsWidget( + # Unit locations and unit waveform plots are always generated + ax_unit_locations = fig.add_subplot(gs[:2, col_counter]) + _ = UnitLocationsWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + backend="matplotlib", + ax=ax_unit_locations, + **unitlocationswidget_kwargs, + ) + col_counter += 1 + + unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") + unit_location = unit_locations[unit_id] + x, y = unit_location[0], unit_location[1] + ax_unit_locations.set_xlim(x - 80, x + 80) + ax_unit_locations.set_ylim(y - 250, y + 250) + ax_unit_locations.set_xticks([]) + ax_unit_locations.set_xlabel(None) + ax_unit_locations.set_ylabel(None) + + ax_unit_waveforms = fig.add_subplot(gs[:2, col_counter]) + _ = UnitWaveformsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, plot_templates=True, + plot_waveforms=sorting_analyzer.has_extension("waveforms"), same_axis=True, plot_legend=False, sparsity=sparsity, backend="matplotlib", - ax=ax2, + ax=ax_unit_waveforms, **unitwaveformswidget_kwargs, ) + col_counter += 1 - ax2.set_title(None) + ax_unit_waveforms.set_title(None) - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapWidget( - sorting_analyzer, - unit_ids=[unit_id], - unit_colors=unit_colors, - use_max_channel=True, - same_axis=False, - backend="matplotlib", - ax=ax3, - **unitwaveformdensitymapwidget_kwargs, - ) - ax3.set_ylabel(None) + if sorting_analyzer.has_extension("waveforms"): + ax_waveform_density = fig.add_subplot(gs[:2, col_counter]) + UnitWaveformDensityMapWidget( + sorting_analyzer, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax_waveform_density, + **unitwaveformdensitymapwidget_kwargs, + ) + col_counter += 1 + ax_waveform_density.set_ylabel(None) if sorting_analyzer.has_extension("correlograms"): - ax4 = fig.add_subplot(gs[:2, 3]) + ax_correlograms = fig.add_subplot(gs[:2, col_counter]) AutoCorrelogramsWidget( sorting_analyzer, unit_ids=[unit_id], unit_colors=unit_colors, backend="matplotlib", - ax=ax4, + ax=ax_correlograms, **autocorrelogramswidget_kwargs, ) + col_counter += 1 - ax4.set_title(None) - ax4.set_yticks([]) + ax_correlograms.set_title(None) + ax_correlograms.set_yticks([]) if sorting_analyzer.has_extension("spike_amplitudes"): - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) + ax_spike_amps = fig.add_subplot(gs[2, : col_counter - 1]) + ax_amps_distribution = fig.add_subplot(gs[2, col_counter - 1]) + axes = np.array([ax_spike_amps, ax_amps_distribution]) AmplitudesWidget( sorting_analyzer, unit_ids=[unit_id], diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index c593836061..3b31eacee5 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -565,7 +565,7 @@ def _update_plot(self, change): channel_locations = self.sorting_analyzer.get_channel_locations() else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] - templates = self.templates.templates_array[unit_indices] + templates = self.templates.get_dense_templates()[unit_indices] templates_shadings = None channel_locations = self.templates.get_channel_locations() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ca09cc4d8f..75c6248f0f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -243,3 +243,93 @@ def array_to_image( output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape) return output_image + + +def make_units_table_from_sorting(sorting, units_table=None): + """ + Make a DataFrame from sorting properties. + Only for properties with ndim=1 + + Parameters + ---------- + sorting : Sorting + The Sorting object + units_table : None | pd.DataFrame + Optionally a existing dataframe. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + + if units_table is None: + import pandas as pd + + units_table = pd.DataFrame(index=sorting.unit_ids) + + for col in sorting.get_property_keys(): + values = sorting.get_property(col) + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table + + +def make_units_table_from_analyzer( + analyzer, + extra_properties=None, +): + """ + Make a DataFrame by aggregating : + * quality metrics + * template metrics + * unit_position + * sorting properties + * extra columns + + This used in sortingview and spikeinterface-gui to display the units table in a flexible way. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + extra_properties : None | dict + Extra columns given as dict. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + import pandas as pd + + all_df = [] + + if analyzer.get_extension("unit_locations") is not None: + locs = analyzer.get_extension("unit_locations").get_data() + df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) + all_df.append(df) + + if analyzer.get_extension("quality_metrics") is not None: + df = analyzer.get_extension("quality_metrics").get_data() + all_df.append(df) + + if analyzer.get_extension("template_metrics") is not None: + df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) + + if len(all_df) > 0: + units_table = pd.concat(all_df, axis=1) + else: + units_table = pd.DataFrame(index=analyzer.unit_ids) + + make_units_table_from_sorting(analyzer.sorting, units_table=units_table) + + if extra_properties is not None: + for col, values in extra_properties.items(): + # the ndim = 1 is important because we need column only for the display in gui. + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index a6cc562ba2..d594414287 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -1,10 +1,12 @@ from __future__ import annotations +from warnings import warn + import numpy as np from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json -from warnings import warn +from .utils import make_units_table_from_sorting, make_units_table_from_analyzer def make_serializable(*args): @@ -50,105 +52,55 @@ def handle_display_and_url(widget, view, **backend_kwargs): def generate_unit_table_view( sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, unit_properties: list[str] | None = None, - similarity_scores: npndarray | None = None, + similarity_scores: np.ndarray | None = None, ): import sortingview.views as vv if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): analyzer = sorting_or_sorting_analyzer + units_tables = make_units_table_from_analyzer(analyzer) sorting = analyzer.sorting else: sorting = sorting_or_sorting_analyzer - analyzer = None - - # Find available unit properties from all sources - sorting_props = list(sorting.get_property_keys()) - if analyzer is not None: - if analyzer.get_extension("quality_metrics") is not None: - qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) - qm_data = analyzer.get_extension("quality_metrics").get_data() - else: - qm_props = [] - if analyzer.get_extension("template_metrics") is not None: - tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) - tm_data = analyzer.get_extension("template_metrics").get_data() - else: - tm_props = [] - # Check for any overlaps and warn user if any - all_props = sorting_props + qm_props + tm_props - else: - all_props = sorting_props - qm_props = [] - tm_props = [] - qm_data = None - tm_data = None - - overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] - if len(overlap_props) > 0: - warn( - f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}" - ) - - # Get unit properties + units_tables = make_units_table_from_sorting(sorting) + # analyzer = None + if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] else: + # keep only selected columns + unit_properties = np.array(unit_properties) + keep = np.isin(unit_properties, units_tables.columns) + if sum(keep) < len(unit_properties): + warn(f"Some unit properties are not in the sorting: {unit_properties[~keep]}") + unit_properties = unit_properties[keep] + units_tables = units_tables.loc[:, unit_properties] + + dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"} + ut_columns = [] + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + txt_dtype = dtype_convertor[values.dtype.kind] + ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=txt_dtype)) + ut_rows = [] - values = {} - valid_unit_properties = [] - - # Create columns for each property - for prop_name in unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - else: - warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") - continue - - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - # Create rows for each unit - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - - # Check for NaN values and round floats - val0 = np.array(property_values[0]) - if val0.dtype.kind == "f": - if np.isnan(property_values[ui]): - continue - property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False) - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + for unit_index, unit_id in enumerate(sorting.unit_ids): + row_values = {} + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + value = values[unit_index] + if values.dtype.kind == "f": + # Check for NaN values and round floats + if np.isnan(values[unit_index]): + continue + value = np.format_float_positional(value, precision=4, fractional=False) + row_values[col] = value + ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table