From 36769caf7302e8f35323bc0a4f22195de7d0b193 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sat, 4 Nov 2023 13:52:29 -0400 Subject: [PATCH 01/48] add import docs --- doc/import.rst | 81 ++++++++++++++++++++++++++++++++++++++++++++ doc/index.rst | 1 + doc/installation.rst | 12 ++++--- doc/modules/core.rst | 31 ----------------- 4 files changed, 90 insertions(+), 35 deletions(-) create mode 100644 doc/import.rst diff --git a/doc/import.rst b/doc/import.rst new file mode 100644 index 0000000000..8024ae5be8 --- /dev/null +++ b/doc/import.rst @@ -0,0 +1,81 @@ +Importing SpikeInterface +======================== + +SpikeInterface allows for the generation of powerful and reproducible spike sorting pipelines. +Flexibility is built into the package starting from import to maximize the productivity of +the developer and the scientist. Thus there are three ways that SpikeInterface and its components +can be imported: + + +Importing by Module +------------------- + +Since each spike sorting pipeline involves a series of often repeated steps, many of the developers +working on SpikeInterface recommend importing in a module by module fashion. This will allow you to +keep track of your processing steps (preprocessing, postprocessing, quality metrics, etc.). This can +be accomplished by: + +.. code-block:: python + + import spikeinterface as si + +to import the :code:`core` module followed by: + +.. code-block:: python + + import spikeinterface.extractors as se + import spikeinterface.preprocessing as spre + import spikeinterface.sorters as ss + import spikinterface.postprocessing as spost + import spikeinterface.qualitymetrics as sqm + import spikeinterface.exporters as sexp + import spikeinterface.comparsion as scmp + import spikeinterface.curation as scur + import spikeinterface.sortingcomponents as sc + import spikeinterface.widgets as sw + +to import any of the other modules you wish to use. + +The benefit of this approach is that it is lighter and faster than importing the whole package and allows +you to choose which of the modules you actually want to use. If you don't plan to export the results out of +SpikeInterface than you don't have to :code:`import spikeinterface.exporters`. Additionally the documentation +of the package is set-up in a modular fashion, so if you have a problem with :code:`spikeinterface.curation`, +you will know to go to the :code:`curation` section of this documention. The disadvantage of this approach is +that you have more aliases to keep track of. + + +Flat Import +----------- + +A second option is to import the SpikeInterface package in :code:`full` mode. This would be similar to +what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`). To accomplish this one does: + + +.. code-block:: python + + import spikeinterface.full as si + + +This import statement will import all of SpikeInterface modules as one flattened module. +Note that importing :code:`spikeinterface.full` will take a few extra seconds, because some modules use +just-in-time :code:`numba` compilation performed at the time of import. +We recommend this approach for advanced users, since it requires a deeper knowledge of the API. The advantage +being that users with advanced API knowledge can access all functions using one alias. + + +Importing Individual Functions +------------------------------ + +Finally, some users may find it useful to have extremely light imports and only import the exact functions +they plan to use. This can easily be accomplished by importing functions directly into the name space. + +For example: + +.. code-block:: python + + from spikeinterface.preprocessing import bandpass_filter, common_reference + from spikeinterface.core import extract_waveforms + from spikeinterface.extractors import read_binary + +As mentioned this approach only imports exactly what you plan on using so is the most minimalist. It does require +knowledge of the API to know which module to pull a function from. diff --git a/doc/index.rst b/doc/index.rst index df76a1a4c2..57ce14ed44 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -40,6 +40,7 @@ SpikeInterface is made of several modules to deal with different aspects of the overview installation + import modules/index how_to/index modules_gallery/index diff --git a/doc/installation.rst b/doc/installation.rst index acc5117249..05472591be 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -14,7 +14,7 @@ To install the current release version, you can use: The :code:`[full]` option installs all the extra dependencies for all the different sub-modules. -Note that if using Z shell (:code:`zsh` - the default shell on mac), you will need to use quotes (:code:`pip install "spikeinterface[full]"`). +Note that if using Z shell (:code:`zsh` - the default shell on macOS), you will need to use quotes (:code:`pip install "spikeinterface[full]"`). To install all interactive widget backends, you can use: @@ -63,14 +63,14 @@ as :code:`spikeinterface` strongly relies on these packages to interface with va It is also sometimes useful to have local copies of :code:`neo` and :code:`probeinterface` to make changes to the code. To achieve this, repeat the first set of commands, -replacing `https://github.com/SpikeInterface/spikeinterface.git` with the appropriate repository in the first code block of this section. +replacing :code:`https://github.com/SpikeInterface/spikeinterface.git` with the appropriate repository in the first code block of this section. For beginners ------------- We provide some installation tips for beginners in Python here: -https://github.com/SpikeInterface/spikeinterface/tree/master/installation_tips +https://github.com/SpikeInterface/spikeinterface/tree/main/installation_tips @@ -89,12 +89,16 @@ Requirements Sub-modules have more dependencies, so you should also install: * zarr + * h5py * scipy * pandas * xarray - * sklearn + * scikit-learn * networkx * matplotlib + * numba + * distinctipy + * cude-python (for non-macOS users) All external spike sorters can be either run inside containers (Docker or Singularity - see :ref:`containerizedsorters`) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 4c03950b1d..b25648a7a0 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -22,37 +22,6 @@ All classes support: * multiple segments, where each segment is a contiguous piece of data (recording, sorting, events). -Import rules ------------- - -Importing the SpikeInterface module - -.. code-block:: python - - import spikeinterface as si - -will only import the :code:`core` module. Other submodules must be imported separately: - -.. code-block:: python - - import spikeinterface.extractors as se - import spikeinterface.sorters as ss - import spikeinterface.widgets as sw - - -A second option is to import the SpikeInterface package in :code:`full` mode: - -.. code-block:: python - - import spikeinterface.full as si - -This import statement will import all of SpikeInterface modules as a flattened module. -Note that importing :code:`spikeinterface.full` will take a few extra seconds, because some modules use -just-in-time :code:`numba` compilation performed at the time of import. -We recommend this approach to advanced users, since it requires a deeper knowledge of the API. - - - Recording --------- From c40b3ad715ac2a57edf67b48a71cafd99a47603b Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:02:50 -0500 Subject: [PATCH 02/48] one round of edits --- doc/import.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/import.rst b/doc/import.rst index 8024ae5be8..eaeec6002e 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -38,8 +38,8 @@ to import any of the other modules you wish to use. The benefit of this approach is that it is lighter and faster than importing the whole package and allows you to choose which of the modules you actually want to use. If you don't plan to export the results out of -SpikeInterface than you don't have to :code:`import spikeinterface.exporters`. Additionally the documentation -of the package is set-up in a modular fashion, so if you have a problem with :code:`spikeinterface.curation`, +SpikeInterface then you don't have to :code:`import spikeinterface.exporters`. Additionally the documentation +of the package is set-up in a modular fashion, so if you have a problem with the module :code:`spikeinterface.curation`, you will know to go to the :code:`curation` section of this documention. The disadvantage of this approach is that you have more aliases to keep track of. @@ -56,11 +56,11 @@ what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`). To ac import spikeinterface.full as si -This import statement will import all of SpikeInterface modules as one flattened module. +This import statement will import all of the SpikeInterface modules as one flattened module. Note that importing :code:`spikeinterface.full` will take a few extra seconds, because some modules use just-in-time :code:`numba` compilation performed at the time of import. We recommend this approach for advanced users, since it requires a deeper knowledge of the API. The advantage -being that users with advanced API knowledge can access all functions using one alias. +being that users can access all functions using one alias. Importing Individual Functions @@ -77,5 +77,5 @@ For example: from spikeinterface.core import extract_waveforms from spikeinterface.extractors import read_binary -As mentioned this approach only imports exactly what you plan on using so is the most minimalist. It does require -knowledge of the API to know which module to pull a function from. +As mentioned this approach only imports exactly what you plan on using so it is the most minimalist. It does require +knowledge of the API to know which module to pull a function from. It could also lead to naming clashes if pulling functions directly from other scientific libraries. From 5d318fa22194d7d58dcc7abe1864d86418783e22 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:13:54 -0500 Subject: [PATCH 03/48] Sam's hot takes Co-authored-by: Garcia Samuel --- doc/import.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/import.rst b/doc/import.rst index eaeec6002e..5da9389df0 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -59,8 +59,8 @@ what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`). To ac This import statement will import all of the SpikeInterface modules as one flattened module. Note that importing :code:`spikeinterface.full` will take a few extra seconds, because some modules use just-in-time :code:`numba` compilation performed at the time of import. -We recommend this approach for advanced users, since it requires a deeper knowledge of the API. The advantage -being that users can access all functions using one alias. +We recommend this approach for advanced (or lazy) users, since it requires a deeper knowledge of the API. The advantage +being that users can access all functions using one alias without the need of memorizing all aliases. Importing Individual Functions From 17c91a27eb4d2063848d45573b19fb1bf7534f5e Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Thu, 9 Nov 2023 05:58:26 -0500 Subject: [PATCH 04/48] Heberto's fix Co-authored-by: Heberto Mayorquin --- doc/installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/installation.rst b/doc/installation.rst index 05472591be..08f9077a1d 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -98,7 +98,7 @@ Sub-modules have more dependencies, so you should also install: * matplotlib * numba * distinctipy - * cude-python (for non-macOS users) + * cuda-python (for non-macOS users) All external spike sorters can be either run inside containers (Docker or Singularity - see :ref:`containerizedsorters`) From 6126888e49795d58a6284e7a5540dafbc56264ca Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Thu, 9 Nov 2023 06:55:35 -0500 Subject: [PATCH 05/48] some updates based on feedback --- doc/import.rst | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/doc/import.rst b/doc/import.rst index 5da9389df0..19fc74c3a8 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -36,19 +36,22 @@ to import the :code:`core` module followed by: to import any of the other modules you wish to use. -The benefit of this approach is that it is lighter and faster than importing the whole package and allows -you to choose which of the modules you actually want to use. If you don't plan to export the results out of -SpikeInterface then you don't have to :code:`import spikeinterface.exporters`. Additionally the documentation -of the package is set-up in a modular fashion, so if you have a problem with the module :code:`spikeinterface.curation`, -you will know to go to the :code:`curation` section of this documention. The disadvantage of this approach is -that you have more aliases to keep track of. +The benefit of this approach is that it is lighter than importing the whole library as a flat module and allows +you to choose which of the modules you actually want to use. It also reminds you what step of the pipeline each +submodule is meant to be used for. If you don't plan to export the results out of SpikeInterface then you +don't have to :code:`import spikeinterface.exporters`. Additionally the documentation of SpikeInterface is set-up +in a modular fashion, so if you have a problem with the submodule :code:`spikeinterface.curation`,you will know +to go to the :code:`curation` section of this documention. The disadvantage of this approach is that you have +more aliases to keep track of. Flat Import ----------- A second option is to import the SpikeInterface package in :code:`full` mode. This would be similar to -what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`). To accomplish this one does: +what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`), which offer the majority of +their functionality with a single alias and the option to import additional functionality separately. +To accomplish this one does: .. code-block:: python @@ -78,4 +81,5 @@ For example: from spikeinterface.extractors import read_binary As mentioned this approach only imports exactly what you plan on using so it is the most minimalist. It does require -knowledge of the API to know which module to pull a function from. It could also lead to naming clashes if pulling functions directly from other scientific libraries. +knowledge of the API to know which module to pull a function from. It could also lead to naming clashes if pulling +functions directly from other scientific libraries. Type :code:`import this` for more information. From ec09322d4738557e104aee264988c1a477a212d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Nov 2023 11:55:53 +0000 Subject: [PATCH 06/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/import.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/import.rst b/doc/import.rst index 19fc74c3a8..430e53ad4e 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -38,10 +38,10 @@ to import any of the other modules you wish to use. The benefit of this approach is that it is lighter than importing the whole library as a flat module and allows you to choose which of the modules you actually want to use. It also reminds you what step of the pipeline each -submodule is meant to be used for. If you don't plan to export the results out of SpikeInterface then you -don't have to :code:`import spikeinterface.exporters`. Additionally the documentation of SpikeInterface is set-up -in a modular fashion, so if you have a problem with the submodule :code:`spikeinterface.curation`,you will know -to go to the :code:`curation` section of this documention. The disadvantage of this approach is that you have +submodule is meant to be used for. If you don't plan to export the results out of SpikeInterface then you +don't have to :code:`import spikeinterface.exporters`. Additionally the documentation of SpikeInterface is set-up +in a modular fashion, so if you have a problem with the submodule :code:`spikeinterface.curation`,you will know +to go to the :code:`curation` section of this documention. The disadvantage of this approach is that you have more aliases to keep track of. @@ -49,8 +49,8 @@ Flat Import ----------- A second option is to import the SpikeInterface package in :code:`full` mode. This would be similar to -what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`), which offer the majority of -their functionality with a single alias and the option to import additional functionality separately. +what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`), which offer the majority of +their functionality with a single alias and the option to import additional functionality separately. To accomplish this one does: @@ -81,5 +81,5 @@ For example: from spikeinterface.extractors import read_binary As mentioned this approach only imports exactly what you plan on using so it is the most minimalist. It does require -knowledge of the API to know which module to pull a function from. It could also lead to naming clashes if pulling +knowledge of the API to know which module to pull a function from. It could also lead to naming clashes if pulling functions directly from other scientific libraries. Type :code:`import this` for more information. From 7b06d495a75083bd879042495e6e32c7e30f518f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sat, 11 Nov 2023 17:04:26 -0500 Subject: [PATCH 07/48] update for api changes + typos --- doc/install_sorters.rst | 76 +++++++++++++------- doc/modules/core.rst | 112 ++++++++++++++++++------------ doc/modules/curation.rst | 3 +- doc/modules/exporters.rst | 12 ++-- doc/modules/postprocessing.rst | 2 +- doc/modules/sorters.rst | 10 +-- doc/modules/sortingcomponents.rst | 18 ++--- doc/modules/widgets.rst | 16 ++--- doc/viewers.rst | 5 +- 9 files changed, 149 insertions(+), 105 deletions(-) diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 10a3185c5c..e805f03eed 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -32,8 +32,8 @@ Some novel spike sorting algorithms are implemented directly in SpikeInterface u :py:mod:`spikeinterface.sortingcomponents` module. Checkout the :ref:`si_based` section of this page for more information! -If you experience installation problems please directly contact the authors of theses tools or write on the -related mailing list, google group, etc. +If you experience installation problems please directly contact the authors of these tools or write on the +related mailing list, google group, GitHub issue page, etc. Please feel free to enhance this document with more installation tips. @@ -251,31 +251,6 @@ Combinato # or using CombinatoSorter.set_combinato_path() -Klusta (LEGACY) -^^^^^^^^^^^^^^^ - -* Python -* Requires SpikeInterface<0.96.0 (and Python 3.7) -* Url: https://github.com/kwikteam/klusta -* Authors: Cyrille Rossant, Shabnam Kadir, Dan Goodman, Max Hunter, Kenneth Harris -* Installation:: - - pip install Cython h5py tqdm - pip install click klusta klustakwik2 - -* See also: https://github.com/kwikteam/phy - - -Yass (LEGACY) -^^^^^^^^^^^^^ - -* Python, CUDA, torch -* Requires SpikeInterface<0.96.0 (and Python 3.7) -* Url: https://github.com/paninski-lab/yass -* Authors: JinHyung Lee, Catalin Mitelut, Liam Paninski -* Installation:: - - https://github.com/paninski-lab/yass/wiki/Installation-Local .. _si_based: @@ -302,3 +277,50 @@ working not only at peak times but at all times, recovering more spikes close to pip install hdbscan pip install spikeinterface pip install numba (or conda install numba as recommended by conda authors) + + +Tridesclous2 +^^^^^^^^^^^^ + +This is an upgraded version of Tridesclous, natively written in SpikeInterface. +#Same add his notes. + +* Python +* Requires: HDBSCAN and Numba +* Authors: Samuel Garcia +* Installation:: + + pip install hdbscan + pip install spikeinterface + pip install numba + + + +Legacy Sorters +-------------- + +Klusta (LEGACY) +^^^^^^^^^^^^^^^ + +* Python +* Requires SpikeInterface<0.96.0 (and Python 3.7) +* Url: https://github.com/kwikteam/klusta +* Authors: Cyrille Rossant, Shabnam Kadir, Dan Goodman, Max Hunter, Kenneth Harris +* Installation:: + + pip install Cython h5py tqdm + pip install click klusta klustakwik2 + +* See also: https://github.com/kwikteam/phy + + +Yass (LEGACY) +^^^^^^^^^^^^^ + +* Python, CUDA, torch +* Requires SpikeInterface<0.96.0 (and Python 3.7) +* Url: https://github.com/paninski-lab/yass +* Authors: JinHyung Lee, Catalin Mitelut, Liam Paninski +* Installation:: + + https://github.com/paninski-lab/yass/wiki/Installation-Local diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 4c03950b1d..ef7b266a86 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -56,7 +56,7 @@ We recommend this approach to advanced users, since it requires a deeper knowled Recording --------- -The :py:class:`~spikeinterface.core.BaseRecording` class serves as basis for all +The :py:class:`~spikeinterface.core.BaseRecording` class serves as the basis for all :code:`Recording` classes. It interfaces with the raw traces and has the following features: @@ -86,7 +86,7 @@ with 16 channels: # retrieve raw traces between frames 100 and 200 traces = recording.get_traces(start_frame=100, end_frame=200, segment_index=0) - # retrieve raw traces only for the first 4 of the channels + # retrieve raw traces only for the first 4 channels traces_slice = recording.get_traces(start_frame=100, end_frame=200, segment_index=0, channel_ids=channel_ids[:4]) # retrieve traces after scaling to uV @@ -119,7 +119,7 @@ with 16 channels: # 'recording_by_group' is a dict with group as keys (0,1,2,3) and channel # sliced recordings as values - # set times (for synchronization) - assume out times start at 300 seconds + # set times (for synchronization) - assume our times start at 300 seconds timestamps = np.arange(num_samples) / sampling_frequency + 300 recording.set_times(timestamps, segment_index=0) @@ -127,7 +127,7 @@ with 16 channels: Sorting ------- -The :py:class:`~spikeinterface.core.BaseSorting` class serves as basis for all :code:`Sorting` classes. +The :py:class:`~spikeinterface.core.BaseSorting` class serves as the basis for all :code:`Sorting` classes. It interfaces with a spike-sorted output and has the following features: * retrieve spike trains of different units @@ -150,7 +150,7 @@ with 10 units: # retrieve spike trains for a unit (returned as sample indices) unit0 = unit_ids[0] spike_train = sorting.get_unit_spike_train(unit_id=unit0, segment_index=0) - # retrieve spikes between 100 and 200 + # retrieve spikes between frames 100 and 200 spike_train_slice = sorting.get_unit_spike_train(unit_id=unit0, start_frame=100, end_frame=200, segment_index=0) @@ -167,13 +167,13 @@ with 10 units: sorting.annotate(date="Spike sorted today") sorting.get_annotation(key="date") - # get new sorting with the first 10s of spike trains + # get new sorting within the first 10s of the spike trains sorting_slice_frames = sorting.frame_slice(start_frame=0, end_frame=int(10*sampling_frequency)) - # get new sorting with the first 4 units + # get new sorting with only the first 4 units sorting_select_units = sorting.select_units(unit_ids=unit_ids[:4]) - # register 'recording' from previous and get spike trains in seconds + # register 'recording' from the previous example and get the spike trains in seconds sorting.register_recording(recording) spike_train_s = sorting.get_unit_spike_train(unit_id=unit0, segment_index=0, return_times=True) @@ -183,10 +183,10 @@ with 10 units: Internally, any sorting object can construct 2 internal caches: - 1. a list (per segment) of dict (per unit) of numpy.array. This cache is usefull when accessing spiketrains unit - per unit across segments. - 2. a unique numpy.array with structured dtype aka "spikes vector". This is usefull for processing by small chunk of - time, like extract amplitudes from a recording. + 1. a list (per segment) of dict (per unit) of numpy.array. This cache is useful when accessing spike trains on a unit + per unit basis across segments. + 2. a unique numpy.array with structured dtype aka "spikes vector". This is useful for processing by small chunks of + time, like for extracting amplitudes from a recording. WaveformExtractor @@ -194,12 +194,12 @@ WaveformExtractor The :py:class:`~spikeinterface.core.WaveformExtractor` class is the core object to combine a :py:class:`~spikeinterface.core.BaseRecording` and a :py:class:`~spikeinterface.core.BaseSorting` object. -Waveforms are very important for additional analysis, and the basis of several postprocessing and quality metrics +Waveforms are very important for additional analyses, and the basis of several postprocessing and quality metrics computations. The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to: -* extract and waveforms +* extract waveforms * sub-sample spikes for waveform extraction * compute templates (i.e. average extracellular waveforms) with different modes * save waveforms in a folder (in numpy / `Zarr `_) for easy retrieval @@ -215,16 +215,28 @@ Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s .. code-block:: python # extract dense waveforms on 500 spikes per unit - we = extract_waveforms(recording, sorting, folder="waveforms", - max_spikes_per_unit=500) + we = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder="waveforms", + max_spikes_per_unit=500 + overwrite=True) # same, but with parallel processing! (1s chunks processed by 8 jobs) job_kwargs = dict(n_jobs=8, chunk_duration="1s") - we = extract_waveforms(recording, sorting, folder="waveforms_par", - max_spikes_per_unit=500, overwrite=True, + we = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder="waveforms_parallel", + max_spikes_per_unit=500, + overwrite=True, **job_kwargs) # same, but in-memory - we_mem = extract_waveforms(recording, sorting, folder=None, - mode="memory", max_spikes_per_unit=500, + we_mem = extract_waveforms(recording=recording, + sorting=sorting, + sparse=False, + folder=None, + mode="memory", + max_spikes_per_unit=500, **job_kwargs) # load pre-computed waveforms @@ -243,13 +255,16 @@ Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s template_stds = we.get_all_templates(mode="std") # save to Zarr - we_zarr = we.save(folder="waveforms.zarr", format="zarr") + we_zarr = we.save(folder="waveforms_zarr", format="zarr") # extract sparse waveforms (see Sparsity section) # this will use 50 spike per unit to estimate the sparsity of 40um radius for each unit - we_sparse = extract_waveforms(recording, sorting, folder="waveforms_sparse", - max_spikes_per_unit=500, sparse=True, - method="radius", radius_um=40, + we_sparse = extract_waveforms(recording=recording, + sorting=sorting, + folder="waveforms_sparse", + max_spikes_per_unit=500, + method="radius", + radius_um=40, num_spikes_for_sparsity=50) @@ -265,11 +280,14 @@ In order to make a waveform folder portable (e.g. copied to another location or # save the sorting object in the "processed" folder sorting = sorting.save(folder=processed_folder / "sorting") # extract waveforms using relative paths - we = extract_waveforms(recording, sorting, folder=processed_folder / "waveforms", + we = extract_waveforms(recording=recording, + sorting=sorting, + folder=processed_folder / "waveforms", use_relative_path=True) # the "processed" folder is now portable, and the waveform extractor can be reloaded # from a different location/machine (without loading the recording) - we_loaded = si.load_waveforms(processed_folder / "waveforms", with_recording=False) + we_loaded = si.load_waveforms(folder=processed_folder / "waveforms", + with_recording=False) Event @@ -278,7 +296,7 @@ Event The :py:class:`~spikeinterface.core.BaseEvent` class serves as basis for all :code:`Event` classes. It allows one to retrieve events and epochs (e.g. TTL pulses). Internally, events are represented as numpy arrays with a structured dtype. The structured dtype -must contain the :code:`time` field, which represent the event times in seconds. Other fields are +must contain the :code:`time` field, which represents the event times in seconds. Other fields are optional. Here we assume :code:`event` is a :py:class:`~spikeinterface.core.BaseEvent` object @@ -313,7 +331,7 @@ threshold and only record the times at which a peak was detected and the wavefor the peak. **NOTE**: while we support this class (mainly for legacy formats), this approach is a bad practice -and highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform +and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform template matching to recover spikes! Here we assume :code:`snippets` is a :py:class:`~spikeinterface.core.BaseSnippets` object @@ -374,9 +392,9 @@ The probe has 4 shanks, which can be loaded as separate groups (and spike sorted # set probe recording_w_probe = recording.set_probe(probe) - # set probe with group info + # set probe with group info and return a new recording object recording_w_probe = recording.set_probe(probe, group_mode="by_shank") - # set probe in place + # set probe in place, ie, modify the current recording recording.set_probe(probe, group_mode="by_shank", in_place=True) # retrieve probe @@ -420,7 +438,7 @@ There are several methods to compute sparsity, including: * | :code:`method="radius"`: selects the channels based on the channel locations. For example, using a | :code:`radius_um=40`, will select, for each unit, the channels which are whithin 40um of the channel with the - | largest amplitude (*extremum channel*). **This is the recommended method for high-density probes** + | largest amplitude (*the extremum channel*). **This is the recommended method for high-density probes** * | :code:`method="best_channels"`: selects the best :code:`num_channels` channels based on their amplitudes. Note that | in this case the selected channels might not be close to each other. * | :code:`method="threshold"`: selects channels based on an SNR threshold (:code:`threshold` argument) @@ -432,7 +450,7 @@ The computed sparsity can be used in several postprocessing and visualization fu .. code-block:: python - we_sparse = we.save(we, sparsity=sparsity, folder="waveforms_sparse") + we_sparse = we.save(waveform_extractor=we, sparsity=sparsity, folder="waveforms_sparse") The :code:`we_sparse` object will now have an associated sparsity (:code:`we.sparsity`), which is automatically taken into consideration for downstream analysis (with the :py:meth:`~spikeinterface.core.WaveformExtractor.is_sparse` @@ -442,6 +460,10 @@ waveforms folder. .. _save_load: +**NOTE:** As of SpikeInterface 0.99.0, :code:`extract_waveforms` now defaults to :code:`sparse=True`, so that default +behavior is to always have sparse waveforms. To have dense waveforms (the previous default behavior), remember to set +:code:`sparsity=False`. + Saving, loading, and compression -------------------------------- @@ -460,10 +482,12 @@ and annotations associated to the object. The save function also supports parallel processing to speed up the writing process. From a SpikeInterface folder, the saved object can be reloaded with the :code:`load_extractor()` function. -This saving/loading features enables to store SpikeInterface objects efficiently and to distribute processing. +This saving/loading features enables us to store SpikeInterface objects efficiently and to distribute processing. .. code-block:: python + # n_jobs is related to the number of processors you want to use + # n_jobs=-1 indicates to use all available job_kwargs = dict(n_jobs=8, chunk_duration="1s") # save recording to folder in binary (default) format recording_bin = recording.save(folder="recording", **job_kwargs) @@ -475,7 +499,7 @@ This saving/loading features enables to store SpikeInterface objects efficiently sorting_saved = sorting.save(folder="sorting") **NOTE:** the Zarr format by default applies data compression with :code:`Blosc.Zstandard` codec with BIT shuffling. -Any other Zarr-compatible compressor and filters can be applied using the :code:`compressor` and :code:`filters` +Any other Zarr-compatible compressors and filters can be applied using the :code:`compressor` and :code:`filters` arguments. For example, in this case we apply `LZMA `_ and use a `Delta `_ filter: @@ -550,7 +574,7 @@ In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikein but they are not bound to a file. Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to -Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is usefull for +Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is useful for parallel computing. In this example, we create a recording and a sorting object from numpy objects: @@ -585,12 +609,12 @@ In this example, we create a recording and a sorting object from numpy objects: Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or -:py:class:`~spikeinterface.core.SharedMemorySorting` easily like this +:py:class:`~spikeinterface.core.SharedMemorySorting` easily like this: .. code-block:: python # turn any sortinto into NumpySorting - soring_np = sorting.to_numpy_sorting() + sorting_np = sorting.to_numpy_sorting() # or to SharedMemorySorting for parrallel computing sorting_shm = sorting.to_shared_memory_sorting() @@ -602,7 +626,7 @@ Manipulating objects: slicing, aggregating ------------------------------------------- :py:class:`~spikeinterface.core.BaseRecording` (and :py:class:`~spikeinterface.core.BaseSnippets`) -and :py:class:`~spikeinterface.core.BaseSorting` objects can be sliced in the time or channel/unit axis. +and :py:class:`~spikeinterface.core.BaseSorting` objects can be sliced on the time or channel/unit axis. This operations are completely lazy, as there is no data duplication. After slicing or aggregating, the new objects will be a *view* of the original ones. @@ -613,9 +637,9 @@ the new objects will be a *view* of the original ones. recording = read_spikeglx('np_folder') sorting =read_kilosrt('ks_folder') - # keep one channel every ten channels - keep_ids = rec.channel_ids[::10] - sub_recording = rec.channel_slice(channel_ids=keep_ids) + # keep one channel of every tenth channel + keep_ids = recording.channel_ids[::10] + sub_recording = recording.channel_slice(channel_ids=keep_ids) # keep between 5min and 12min fs = recording.sampling_frequency @@ -641,8 +665,8 @@ We can also aggregate (or stack) multiple sortings on the unit axis using the .. code-block:: python - sortingA = read_npz('sortingA.npz') - sortingB = read_npz('sortingB.npz') + sortingA = read_npz_sorting('sortingA.npz') + sortingB = read_npz_sorting('sortingB.npz') sorting_20_units = aggregate_units([sortingA, sortingB]) @@ -706,7 +730,7 @@ object: * :py:func:`~spikeinterface.core.get_chunk_with_margin`: gets traces with a left and right margin * :py:func:`~spikeinterface.core.get_closest_channels`: returns the :code:`num_channels` closest channels to each specified channel * :py:func:`~spikeinterface.core.get_channel_distances`: returns a square matrix with channel distances - * :py:func:`~spikeinterface.core.order_channels_by_depth`: gets channel order in depth: + * :py:func:`~spikeinterface.core.order_channels_by_depth`: gets channel order in depth Template tools diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 23e9e20d96..d533cdcac8 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -108,8 +108,9 @@ The manual curation (including merges and labels) can be applied to a SpikeInter _ = compute_correlograms(waveform_extractor=we) # This loads the data to the cloud for web-based plotting and sharing + # curation=True required for allowing curation in the sortingview gui plot_sorting_summary(waveform_extractor=we, curation=True, backend='sortingview') - # we open the printed link URL in a browswe + # we open the printed link URL in a browser # - make manual merges and labeling # - from the curation box, click on "Save as snapshot (sha1://)" diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 155050ddb0..d9c4be963f 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -31,11 +31,11 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(waveform_extractor=we) - compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global') + _ = compute_spike_amplitudes(waveform_extractor=we) + _ = compute_principal_components(waveform_extractor=we, n_components=3, mode='by_channel_global') # the export process is fast because everything is pre-computed - export_to_phy(wavefor_extractor=we, output_folder='path/to/phy_folder') + export_to_phy(waveform_extractor=we, output_folder='path/to/phy_folder') @@ -74,9 +74,9 @@ with many units! we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf', sparse=True) # some computations are done before to control all options - compute_spike_amplitudes(waveform_extractor=we) - compute_correlograms(waveform_extractor=we) - compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'presence_ratio']) + _ = compute_spike_amplitudes(waveform_extractor=we) + - = compute_correlograms(waveform_extractor=we) + _ = compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'presence_ratio']) # the export process export_report(waveform_extractor=we, output_folder='path/to/spikeinterface-report-folder') diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 112c6e367d..195413e2af 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -18,7 +18,7 @@ of a :code:`WaveformExtractor` will be saved along side the :code:`WaveformExtra This workflow is convenient for retrieval of time-consuming computations (such as pca or spike amplitudes) when reloading a :code:`WaveformExtractor`. -:py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the +:py:class:`~spikeinterface.core.BaseWaveformExtractorExtension` objects are tightly connected to the parent :code:`WaveformExtractor` object, so that operations done on the :code:`WaveformExtractor`, such as saving, loading, or selecting units, will be automatically applied to all extensions. diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 5040b01ec2..e0005e51fe 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -12,9 +12,9 @@ On the other hand SpikeInterface directly implements some internal sorters (**sp that do not depend on external tools, but depend on the :py:mod:`spikeinterface.sortingcomponents` module. **Note that internal sorters are currently experimental and under development**. -A drawback of using external sorters is the installation of these tools. Sometimes they need MATLAB, -specific versions of CUDA, specific gcc versions vary or even worse outdated versions of -Python/NumPy. In that case, SpikeInterface offer the mechanism of running external sorters inside a +A drawback of using external sorters is the separate installation of these tools. Sometimes they need MATLAB, +specific versions of CUDA, specific gcc versions or outdated versions of +Python/NumPy. In this case, SpikeInterface offer the mechanism of running external sorters inside a container (Docker/Singularity) with the sorter pre-installed. See :ref:`containerizedsorters`. @@ -244,7 +244,7 @@ There are three options: the current development version from the :code:`main` branch will be installed in the container. 3. **local copy**: if you installed :code:`spikeinterface` from source and you have some changes in your branch or fork - that are not in the :code:`main` branch, you can install a copy of your :code:`spikeinterface` packahe in the container. + that are not in the :code:`main` branch, you can install a copy of your :code:`spikeinterface` package in the container. To do so, you need to set en environment variable :code:`SPIKEINTERFACE_DEV_PATH` to the location where you cloned the :code:`spikeinterface` repo (e.g. on Linux: :code:`export SPIKEINTERFACE_DEV_PATH="path-to-spikeinterface-clone"`. @@ -397,7 +397,7 @@ to concatenate the recordings before spike sorting and how to split the sorted o on the concatenation. Note that some sorters (tridesclous, spykingcircus2) handle a multi-segments paradigm directly. In -that case we will use the :py:func:`~spikeinterface.core.append_recordings()` function. Many sorters +this case we will use the :py:func:`~spikeinterface.core.append_recordings()` function. Many sorters do not handle multi-segment, and in that case we will use the :py:func:`~spikeinterface.core.concatenate_recordings()` function. diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index 1e58972497..f33a0b3cf2 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -15,7 +15,7 @@ Another advantage of *modularization* is that we can accurately benchmark every For example, what is the performance of peak detection method 1 or 2, provided that the rest of the pipeline is the same? -For now, we have methods for: +Currently, we have methods for: * peak detection * peak localization * peak selection @@ -24,7 +24,7 @@ For now, we have methods for: * clustering * template matching -For some of theses steps, implementations are in a very early stage and are still a bit *drafty*. +For some of these steps, implementations are in a very early stage and are still a bit *drafty*. Signature and behavior may change from time to time in this alpha period development. You can also have a look `spikeinterface blog `_ where there are more detailed @@ -76,7 +76,7 @@ Different methods are available with the :code:`method` argument: **NOTE**: the torch implementations give slightly different results due to a different implementation. -Peak detection, as many sorting components, can be run in parallel. +Peak detection, as many of the other sorting components, can be run in parallel. Peak localization @@ -105,8 +105,8 @@ Currently, the following methods are implemented: * 'center_of_mass' * 'monopolar_triangulation' with optimizer='least_square' This method is from Julien Boussard and Erdem Varol from the Paninski lab. - This has been presented at [NeurIPS](https://nips.cc/Conferences/2021/ScheduleMultitrack?event=26709) - see also [here](https://openreview.net/forum?id=ohfi44BZPC4) + This has been presented at `NeurIPS `_ + see also `here `_ * 'monopolar_triangulation' with optimizer='minimize_with_log_penality' These methods are the same as implemented in :py:mod:`spikeinterface.postprocessing.unit_localization` @@ -133,7 +133,7 @@ Peak selection -------------- When too many peaks are detected a strategy can be used to select (or sub-sample) only some of them before clustering. -This is the strategy used by spyking-circus or tridesclous, for instance. +This is the strategy used by spyking-circus and tridesclous, for instance. Then, clustering is run on this subset of peaks, templates are extracted, and a template-matching step is run to find all spikes. @@ -219,7 +219,7 @@ Here is a short example that depends on the output of "Motion interpolation": from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording recording_corrected = InterpolateMotionRecording(recording=recording_with_drift, motion=motion, temporal_bins=temporal_bins, spatial_bins=spatial_bins - spatial_interpolation_method='kriging, + spatial_interpolation_method='kriging', border_mode='remove_channels') **Notes**: @@ -227,14 +227,14 @@ Here is a short example that depends on the output of "Motion interpolation": * :code:`border_mode` is a very important parameter. It controls dealing with the border because motion causes units on the border to not be present throughout the entire recording. We highly recommend the :code:`border_mode='remove_channels'` because this removes channels on the border that will be impacted by drift. Of course the larger the motion is - the more channels are removed. + the greater the number of channels that would be removed. Clustering ---------- The clustering step remains the central step of spike sorting. -Historically this step was separted into two distinct parts: feature reduction and clustering. +Historically this step was separated into two distinct parts: feature reduction and clustering. In SpikeInterface, we decided to regroup these two steps into the same module. This allows one to compute feature reduction 'on-the-fly' and avoid long computations and storage of large features. diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index f37b2a5a6f..d5a2ee87c6 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -70,7 +70,7 @@ To install it, run: .. code-block:: bash - pip install sortingview figurl-jupyter + pip install sortingview Internally, the processed data to be rendered are uploaded to a public bucket in the cloud, so that they can be visualized via the web (if :code:`generate_url=True`). @@ -78,7 +78,7 @@ When running in a Jupyter notebook or JupyterLab, the sortingview widget will al notebook! To set up the backend, you need to authenticate to `kachery-cloud` using your GitHub account by running -the following command (you will be prompted a link): +the following command (you will be prompted with a link): .. code-block:: bash @@ -196,13 +196,13 @@ The functions have the following additional arguments: .. code-block:: python # sortingview backend - w_ts = sw.plot_traces(recording=recording, backend="ipywidgets") - w_ss = sw.plot_sorting_summary(recording=recording, backend="sortingview") + w_ts = sw.plot_traces(recording=recording, backend="sortingview") + w_ss = sw.plot_sorting_summary(waveform_extractor = we, curation=True, backend="sortingview") **Output:** -* `Timeseries link `_ +* `plot_traces link `_ .. image:: ../images/sv_timeseries.png @@ -278,11 +278,7 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_unit_waveforms` (backends: :code:`matplotlib`, :code:`ipywidgets`) -Legacy plotting functions -^^^^^^^^^^^^^^^^^^^^^^^^^ - -These functions are still part of the package, but they are directly implemented in :code:`matplotlib` without the -more recend backend mechanism: +# Which have been moved over? * :py:func:`~spikeinterface.widgets.plot_rasters` * :py:func:`~spikeinterface.widgets.plot_probe_map` diff --git a/doc/viewers.rst b/doc/viewers.rst index 55463146ce..a906ee29db 100644 --- a/doc/viewers.rst +++ b/doc/viewers.rst @@ -24,7 +24,7 @@ spikeinterface-gui `spikeinterface-gui `_ is a local desktop application which is built on top of :code:`spikeinterface`. -It is the easiest and fastest way to inspect interactively a spike sorting output. +It is the easiest and fastest way to interactively inspect a spike sorting output. It's easy to install and ready to use! Authors: Samuel Garcia @@ -44,6 +44,7 @@ phy --- `phy `_ is the de-facto standard tool for manual curation of a sorting output. -The current drawback of :code:`phy` is that the dataset (including filtered signals and **all** waveforms of spikes) has to be copied in a separate folder and this is very time consuming process and occupies a lot of disk space. +The current drawback of :code:`phy` is that the dataset (including filtered signals and **all** waveforms of spikes) has to be copied +in a separate folder and this is very time consuming process and occupies a lot of disk space. Author : Cyrill Rossant From a0501e3c319799af33fe1c55415db4a51fb8e090 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 10:49:38 +0100 Subject: [PATCH 08/48] WIP --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a16b642dd5..a120d4e97a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -169,4 +169,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorting_folder) - return sorting + return sorting \ No newline at end of file From fcdca11d488c6fd0d92420a6462da704de81c0b7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 11:27:43 +0100 Subject: [PATCH 09/48] Baseline implementation of circus1 --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 373 +++++++++--------- 2 files changed, 198 insertions(+), 177 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a120d4e97a..53a72e2696 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -110,7 +110,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update({"noise_levels": noise_levels}) labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params + recording_f, selected_peaks, method="circus", method_kwargs=clustering_params ) ## We get the labels for our peaks diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 39f46475dc..83a92f1970 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -19,7 +19,53 @@ from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer +from spikeinterface.sortingcomponents.peak_selection import select_peaks +from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection +from sklearn.decomposition import TruncatedSVD +import pickle, json +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) + + +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extractor waveforms at max channel from a peak list + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, + ) + + return all_wfs class CircusClustering: """ @@ -27,8 +73,6 @@ class CircusClustering: """ _default_params = { - "peak_locations": None, - "peak_localization_kwargs": {"method": "center_of_mass"}, "hdbscan_kwargs": { "min_cluster_size": 50, "allow_single_cluster": True, @@ -36,15 +80,17 @@ class CircusClustering: "cluster_selection_method": "leaf", }, "cleaning_kwargs": {}, - "tmp_folder": None, + "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, - "n_pca": 10, - "max_spikes_per_unit": 200, - "ms_before": 1.5, - "ms_after": 2.5, - "cleaning_method": "dip", - "waveform_mode": "memmap", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, + "selection_method": "closest_to_centroid", + "n_svd": 10, + "ms_before": 1, + "ms_after": 1, + "random_seed": 42, + "noise_levels": None, + "shared_memory": True, + "tmp_folder": None, + "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } @classmethod @@ -70,131 +116,109 @@ def _check_params(cls, recording, peaks, params): @classmethod def main_function(cls, recording, peaks, params): - assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - - params = cls._check_params(recording, peaks, params) - d = params + assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - if d["peak_locations"] is None: - from spikeinterface.sortingcomponents.peak_localization import localize_peaks + if "n_jobs" in params["job_kwargs"]: + if params["job_kwargs"]["n_jobs"] == -1: + params["job_kwargs"]["n_jobs"] = os.cpu_count() - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) - else: - peak_locations = d["peak_locations"] - - tmp_folder = d["tmp_folder"] - if tmp_folder is not None: - tmp_folder.mkdir(exist_ok=True) + if "core_dist_n_jobs" in params["hdbscan_kwargs"]: + if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: + params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - location_keys = ["x", "y"] - locations = np.stack([peak_locations[k] for k in location_keys], axis=1) - - chan_locs = recording.get_channel_locations() + d = params + verbose = d["job_kwargs"]["verbose"] peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - spikes = np.zeros(peaks.size, dtype=peak_dtype) - spikes["sample_index"] = peaks["sample_index"] - spikes["segment_index"] = peaks["segment_index"] - spikes["unit_index"] = peaks["channel_index"] - - num_chans = recording.get_num_channels() - sparsity_mask = np.zeros((peaks.size, num_chans), dtype="bool") - - unit_inds = range(num_chans) - chan_distances = get_channel_distances(recording) - - for main_chan in unit_inds: - (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <= params["radius_um"]) - sparsity_mask[main_chan, closest_chans] = True - - if params["waveform_mode"] == "shared_memory": - wf_folder = None - else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "sparse_snippets" - wf_folder.mkdir() fs = recording.get_sampling_frequency() - nbefore = int(params["ms_before"] * fs / 1000.0) - nafter = int(params["ms_after"] * fs / 1000.0) + ms_before = params["ms_before"] + ms_after = params["ms_after"] + nbefore = int(ms_before * fs / 1000.0) + nafter = int(ms_after * fs / 1000.0) num_samples = nbefore + nafter + num_chans = recording.get_num_channels() - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_inds, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=sparsity_mask, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - n_loc = len(location_keys) - import sklearn.decomposition, hdbscan + if d["noise_levels"] is None: + noise_levels = get_noise_levels(recording, return_scaled=False) + else: + noise_levels = d["noise_levels"] - noise_levels = get_noise_levels(recording, return_scaled=False) + np.random.seed(d["random_seed"]) - nb_clusters = 0 - peak_labels = np.zeros(len(spikes), dtype=np.int32) + if params["tmp_folder"] is None: + name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + tmp_folder = get_global_tmp_folder() / name + else: + tmp_folder = Path(params["tmp_folder"]).absolute() - noise = get_random_data_chunks( - recording, - return_scaled=False, - num_chunks_per_segment=params["max_spikes_per_unit"], - chunk_size=nbefore + nafter, - concatenated=False, - seed=None, - ) - noise = np.stack(noise, axis=0) + tmp_folder.mkdir(parents=True, exist_ok=True) - for main_chan, waveforms in wfs_arrays.items(): - idx = np.where(spikes["unit_index"] == main_chan)[0] - (channels,) = np.nonzero(sparsity_mask[main_chan]) - sub_noise = noise[:, :, channels] + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) - if len(waveforms) > 0: - sub_waveforms = waveforms + wfs = few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["n_svd"]) + tsvd.fit(wfs) - wfs = np.swapaxes(sub_waveforms, 1, 2).reshape(len(sub_waveforms), -1) - noise_wfs = np.swapaxes(sub_noise, 1, 2).reshape(len(sub_noise), -1) + model_folder = tmp_folder / "tsvd_model" - n_pca = min(d["n_pca"], len(wfs)) - pca = sklearn.decomposition.PCA(n_pca) + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) - hdbscan_data = np.vstack((wfs, noise_wfs)) + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(fs), + } - pca.fit(wfs) - hdbscan_data_pca = pca.transform(hdbscan_data) - clustering = hdbscan.hdbscan(hdbscan_data_pca, **d["hdbscan_kwargs"]) + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) - noise_labels = clustering[0][len(wfs) :] - valid_labels = clustering[0][: len(wfs)] + # features + features_folder = model_folder / "features" + node0 = PeakRetriever(recording, peaks) - shared_indices = np.intersect1d(np.unique(noise_labels), np.unique(valid_labels)) - for l in shared_indices: - idx_noise = noise_labels == l - idx_valid = valid_labels == l - if np.sum(idx_noise) > np.sum(idx_valid): - valid_labels[idx_valid] = -1 + radius_um = params["radius_um"] + node3 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, + ) - if np.unique(valid_labels).min() == -1: - valid_labels += 1 + node4 = TemporalPCAProjection( + recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder + ) - for l in np.unique(valid_labels): - idx_valid = valid_labels == l - if np.sum(idx_valid) < d["hdbscan_kwargs"]["min_cluster_size"]: - valid_labels[idx_valid] = -1 + # pipeline_nodes = [node0, node1, node2, node3, node4] + pipeline_nodes = [node0, node3, node4] - peak_labels[idx] = valid_labels + nb_clusters + all_pc_data = run_node_pipeline( + recording, + pipeline_nodes, + params["job_kwargs"], + job_name="extracting PCs", + ) - labels = np.unique(valid_labels) - labels = labels[labels >= 0] - nb_clusters += len(labels) + peak_labels = -1 * np.ones(len(peaks), dtype=int) + nb_clusters = 0 + for c in np.unique(peaks['channel_index']): + mask = peaks['channel_index'] == c + tsvd = TruncatedSVD(params["n_svd"]) + sub_data = all_pc_data[mask] + hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) + clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) + local_labels = clustering[0] + valid_clusters = local_labels > -1 + if np.sum(valid_clusters) > 0: + local_labels[valid_clusters] += nb_clusters + peak_labels[mask] = local_labels + nb_clusters += len(np.unique(local_labels[valid_clusters])) labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -202,11 +226,22 @@ def main_function(cls, recording, peaks, params): best_spikes = {} nb_spikes = 0 + import sklearn + all_indices = np.arange(0, peak_labels.size) + max_spikes = params["waveforms"]["max_spikes_per_unit"] + selection_method = params["selection_method"] + for unit_ind in labels: mask = peak_labels == unit_ind - best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[: params["max_spikes_per_unit"]] + if selection_method == "closest_to_centroid": + data = all_pc_data[mask].reshape(np.sum(mask), -1) + centroid = np.median(data, axis=0) + distances = sklearn.metrics.pairwise_distances(centroid[np.newaxis, :], data)[0] + best_spikes[unit_ind] = all_indices[mask][np.argsort(distances)[:max_spikes]] + elif selection_method == "random": + best_spikes[unit_ind] = np.random.permutation(all_indices[mask])[:max_spikes] nb_spikes += best_spikes[unit_ind].size spikes = np.zeros(nb_spikes, dtype=peak_dtype) @@ -222,72 +257,58 @@ def main_function(cls, recording, peaks, params): spikes["segment_index"] = peaks[mask]["segment_index"] spikes["unit_index"] = peak_labels[mask] - if params["waveform_mode"] == "shared_memory": - wf_folder = None + if verbose: + print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) + + sorting_folder = tmp_folder / "sorting" + unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) + sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) + + if params["shared_memory"]: + waveform_folder = None + mode = "memory" else: - assert params["tmp_folder"] is not None, "tmp_folder must be supplied" - wf_folder = params["tmp_folder"] / "dense_snippets" - wf_folder.mkdir() - - cleaning_method = params["cleaning_method"] - - print(f"We found {len(labels)} raw clusters, starting to clean with {cleaning_method}...") - - if cleaning_method == "cosine": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates( - wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"] - ) - - elif cleaning_method == "dip": - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - labels, - nbefore, - nafter, - mode=params["waveform_mode"], - return_scaled=False, - folder=wf_folder, - dtype=recording.get_dtype(), - sparsity_mask=None, - copy=(params["waveform_mode"] == "shared_memory"), - **params["job_kwargs"], - ) - - labels, peak_labels = remove_duplicates_via_dip(wfs_arrays, peak_labels) - - elif cleaning_method == "matching": - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - - sorting = NumpySorting.from_times_labels(spikes["sample_index"], spikes["unit_index"], fs) - we = extract_waveforms( - recording, - sorting, - tmp_folder, - overwrite=True, - ms_before=params["ms_before"], - ms_after=params["ms_after"], - **params["job_kwargs"], - ) - labels, peak_labels = remove_duplicates_via_matching(we, peak_labels, job_kwargs=params["job_kwargs"]) + waveform_folder = tmp_folder / "waveforms" + mode = "folder" + sorting = sorting.save(folder=sorting_folder) + + we = extract_waveforms( + recording, + sorting, + waveform_folder, + **params["job_kwargs"], + **params["waveforms"], + return_scaled=False, + mode=mode, + ) + + cleaning_matching_params = params["job_kwargs"].copy() + for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + if value in cleaning_matching_params: + cleaning_matching_params.pop(value) + cleaning_matching_params["chunk_duration"] = "100ms" + cleaning_matching_params["n_jobs"] = 1 + cleaning_matching_params["verbose"] = False + cleaning_matching_params["progress_bar"] = False + + cleaning_params = params["cleaning_kwargs"].copy() + cleaning_params["tmp_folder"] = tmp_folder + + labels, peak_labels = remove_duplicates_via_matching( + we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + ) + + del we, sorting + + if params["tmp_folder"] is None: shutil.rmtree(tmp_folder) + else: + if not params["shared_memory"]: + shutil.rmtree(tmp_folder / "waveforms") + shutil.rmtree(tmp_folder / "sorting") - print(f"We kept {len(labels)} non-duplicated clusters...") + if verbose: + print("We kept %d non-duplicated clusters..." % len(labels)) return labels, peak_labels + From d5dd8447cdd535d4664cc2243c5ebd608602db65 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 13 Nov 2023 11:34:22 +0100 Subject: [PATCH 10/48] WIP --- src/spikeinterface/sortingcomponents/clustering/circus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 83a92f1970..4a2f36e1e1 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -74,7 +74,7 @@ class CircusClustering: _default_params = { "hdbscan_kwargs": { - "min_cluster_size": 50, + "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1, "cluster_selection_method": "leaf", @@ -83,7 +83,7 @@ class CircusClustering: "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": 10, + "n_svd": 5, "ms_before": 1, "ms_after": 1, "random_seed": 42, From a00c4374abc257cf71b8145db18df06c2883e8d0 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 14:43:12 +0100 Subject: [PATCH 11/48] Circus 1 like --- .../comparison/groundtruthstudy.py | 2 +- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 29 +++---------------- .../clustering/clustering_tools.py | 2 +- 4 files changed, 7 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 0d08922543..adc2898071 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,7 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = " ## " +_key_separator = "--" class GroundTruthStudy: diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 53a72e2696..76cc9684fa 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 5}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 4a2f36e1e1..c193b1f93d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -77,13 +77,13 @@ class CircusClustering: "min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1, - "cluster_selection_method": "leaf", + "cluster_selection_method": "eom", }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": 5, + "n_svd": [6, 6], "ms_before": 1, "ms_after": 1, "random_seed": 42, @@ -93,27 +93,6 @@ class CircusClustering: "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, } - @classmethod - def _check_params(cls, recording, peaks, params): - d = params - params2 = params.copy() - - tmp_folder = params["tmp_folder"] - if params["waveform_mode"] == "memmap": - if tmp_folder is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = Path(os.path.join(get_global_tmp_folder(), name)) - else: - tmp_folder = Path(tmp_folder) - tmp_folder.mkdir() - params2["tmp_folder"] = tmp_folder - elif params["waveform_mode"] == "shared_memory": - assert tmp_folder is None, "tmp_folder must be None for shared_memory" - else: - raise ValueError("'waveform_mode' must be 'memmap' or 'shared_memory'") - - return params2 - @classmethod def main_function(cls, recording, peaks, params): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" @@ -159,7 +138,7 @@ def main_function(cls, recording, peaks, params): few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) wfs = few_wfs[:, :, 0] - tsvd = TruncatedSVD(params["n_svd"]) + tsvd = TruncatedSVD(params["n_svd"][0]) tsvd.fit(wfs) model_folder = tmp_folder / "tsvd_model" @@ -209,7 +188,7 @@ def main_function(cls, recording, peaks, params): nb_clusters = 0 for c in np.unique(peaks['channel_index']): mask = peaks['channel_index'] == c - tsvd = TruncatedSVD(params["n_svd"]) + tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index b4938717f8..66fe660918 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -599,7 +599,7 @@ def remove_duplicates_via_matching( { "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, - "amplitudes": [0.95, 1.05], + "amplitudes": [0.975, 1.025], "omp_min_sps": 0.05, } ) From a7e297018d76317cf5b4c1a55ba9a64c35178d0a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:50:32 +0100 Subject: [PATCH 12/48] WIP for circus 1 --- .../sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/circus.py | 11 +++++------ .../sortingcomponents/clustering/clustering_tools.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 76cc9684fa..e746883259 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -23,7 +23,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, "filtering": {"freq_min": 150, "dtype": "float32"}, - "detection": {"peak_sign": "neg", "detect_threshold": 5}, + "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, "localization": {}, "clustering": {}, diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index c193b1f93d..ef733224bd 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -161,7 +161,7 @@ def main_function(cls, recording, peaks, params): node0 = PeakRetriever(recording, peaks) radius_um = params["radius_um"] - node3 = ExtractSparseWaveforms( + node1 = ExtractSparseWaveforms( recording, parents=[node0], return_output=False, @@ -170,18 +170,17 @@ def main_function(cls, recording, peaks, params): radius_um=radius_um, ) - node4 = TemporalPCAProjection( - recording, parents=[node0, node3], return_output=True, model_folder_path=model_folder + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder ) - # pipeline_nodes = [node0, node1, node2, node3, node4] - pipeline_nodes = [node0, node3, node4] + pipeline_nodes = [node0, node1, node2] all_pc_data = run_node_pipeline( recording, pipeline_nodes, params["job_kwargs"], - job_name="extracting PCs", + job_name="extracting features", ) peak_labels = -1 * np.ones(len(peaks), dtype=int) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 66fe660918..6b6aba892e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.05, + "omp_min_sps": 0.1, } ) From 292c96d53d920f007283450dd6f781a619b054e5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:53:39 +0100 Subject: [PATCH 13/48] Legacy mode --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e746883259..fad744d143 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6b6aba892e..aaddc15b46 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.1, + "omp_min_sps": 0.025, } ) From 91edc001961480b51e32b2c8b8d9f68590eedc82 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 20:53:39 +0100 Subject: [PATCH 14/48] Legacy mode --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index e746883259..fad744d143 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 0.5}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 6b6aba892e..66fe660918 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -600,7 +600,7 @@ def remove_duplicates_via_matching( "waveform_extractor": waveform_extractor, "noise_levels": noise_levels, "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.1, + "omp_min_sps": 0.05, } ) From 5ffa8911ccfe72d1d532462fde2aa2f7a571c855 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 21:48:24 +0100 Subject: [PATCH 15/48] Adding a legacy mode for the clustering, similar as circus 1 --- .../sorters/internal/spyking_circus2.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index fad744d143..955f228ad5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,12 +21,11 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1}, + "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "energy", "threshold": 0.25}, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "localization": {}, - "clustering": {}, + "clustering": {"legacy" : False}, "matching": {}, "apply_preprocessing": True, "shared_memory": True, @@ -109,8 +108,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["tmp_folder"] = sorter_output_folder / "clustering" clustering_params.update({"noise_levels": noise_levels}) + if "legacy" in clustering_params: + legacy = clustering_params["legacy"] + else: + legacy = False + + if legacy: + clustering_method = "circus" + else: + clustering_method = "random_projections" + labels, peak_labels = find_cluster_from_peaks( - recording_f, selected_peaks, method="circus", method_kwargs=clustering_params + recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params ) ## We get the labels for our peaks From bd3e3b04f73828323e9ceb9fd6fefca7f6a103f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:09:38 +0000 Subject: [PATCH 16/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 12 +++++++++--- .../sortingcomponents/clustering/circus.py | 12 +++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 955f228ad5..29de8a4b0d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -21,11 +21,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "energy", "threshold": 0.25}, + "waveforms": { + "max_spikes_per_unit": 200, + "overwrite": True, + "sparse": True, + "method": "energy", + "threshold": 0.25, + }, "filtering": {"freq_min": 150, "dtype": "float32"}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "clustering": {"legacy" : False}, + "clustering": {"legacy": False}, "matching": {}, "apply_preprocessing": True, "shared_memory": True, @@ -178,4 +184,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorting_folder) - return sorting \ No newline at end of file + return sorting diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ef733224bd..d7d94f73cc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -67,6 +67,7 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs + class CircusClustering: """ hdbscan clustering on peak_locations previously done by localize_peaks() @@ -135,7 +136,9 @@ def main_function(cls, recording, peaks, params): # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"] + ) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["n_svd"][0]) @@ -185,12 +188,12 @@ def main_function(cls, recording, peaks, params): peak_labels = -1 * np.ones(len(peaks), dtype=int) nb_clusters = 0 - for c in np.unique(peaks['channel_index']): - mask = peaks['channel_index'] == c + for c in np.unique(peaks["channel_index"]): + mask = peaks["channel_index"] == c tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) - clustering = hdbscan.hdbscan(hdbscan_data, **d['hdbscan_kwargs']) + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) local_labels = clustering[0] valid_clusters = local_labels > -1 if np.sum(valid_clusters) > 0: @@ -289,4 +292,3 @@ def main_function(cls, recording, peaks, params): print("We kept %d non-duplicated clusters..." % len(labels)) return labels, peak_labels - From d970f8a27bfb00ab1180d2ca06d885496f686f52 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 22:14:21 +0100 Subject: [PATCH 17/48] Minor edits --- .../sorters/internal/spyking_circus2.py | 11 +++++++++-- .../sortingcomponents/clustering/circus.py | 5 +++-- .../clustering/random_projections.py | 5 +++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 955f228ad5..c94c49a4bf 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -109,7 +109,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update({"noise_levels": noise_levels}) if "legacy" in clustering_params: - legacy = clustering_params["legacy"] + legacy = clustering_params.pop("legacy") else: legacy = False @@ -147,7 +147,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( - recording_f, sorting, waveforms_folder, mode=mode, **waveforms_params, return_scaled=False + recording_f, + sorting, + waveforms_folder, + return_scaled=False, + precompute_template=["median"], + mode=mode, + **waveforms_params + ) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index ef733224bd..8a511b5fdc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -254,10 +254,11 @@ def main_function(cls, recording, peaks, params): recording, sorting, waveform_folder, - **params["job_kwargs"], - **params["waveforms"], return_scaled=False, + precompute_template=["median"], mode=mode, + **params["job_kwargs"], + **params["waveforms"] ) cleaning_matching_params = params["job_kwargs"].copy() diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 72acd49f4f..3053bfbdd0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -219,10 +219,11 @@ def sigmoid(x, L, x0, k, b): recording, sorting, waveform_folder, - **params["job_kwargs"], - **params["waveforms"], return_scaled=False, mode=mode, + precompute_template=["median"], + **params["job_kwargs"], + **params["waveforms"], ) cleaning_matching_params = params["job_kwargs"].copy() From 954cb450d7a84567e34de65a5d3eba0d8f02b5b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:14:52 +0000 Subject: [PATCH 18/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 11 +++++------ .../sortingcomponents/clustering/circus.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index f12d5c4fb9..28b9652a3a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -153,14 +153,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): waveforms_folder = sorter_output_folder / "waveforms" we = extract_waveforms( - recording_f, - sorting, - waveforms_folder, + recording_f, + sorting, + waveforms_folder, return_scaled=False, precompute_template=["median"], - mode=mode, - **waveforms_params - + mode=mode, + **waveforms_params, ) ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index d20e33d244..24f4b29718 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -261,7 +261,7 @@ def main_function(cls, recording, peaks, params): precompute_template=["median"], mode=mode, **params["job_kwargs"], - **params["waveforms"] + **params["waveforms"], ) cleaning_matching_params = params["job_kwargs"].copy() From 0343ac2ccc3c26b0501d001505792067c765e2e7 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 14 Nov 2023 22:44:55 +0100 Subject: [PATCH 19/48] Patch for hdbscan --- src/spikeinterface/sortingcomponents/clustering/circus.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index d20e33d244..44cddc4f70 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -193,8 +193,11 @@ def main_function(cls, recording, peaks, params): tsvd = TruncatedSVD(params["n_svd"][1]) sub_data = all_pc_data[mask] hdbscan_data = tsvd.fit_transform(sub_data.reshape(len(sub_data), -1)) - clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) - local_labels = clustering[0] + try: + clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) + local_labels = clustering[0] + except Exception: + local_labels = -1 * np.ones(len(hdbscan_data)) valid_clusters = local_labels > -1 if np.sum(valid_clusters) > 0: local_labels[valid_clusters] += nb_clusters From d31d402f880c6dc2224332acd622319f6b98a620 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 14:17:12 +0100 Subject: [PATCH 20/48] Still a gap --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 + src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 28b9652a3a..b344606c52 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -74,6 +74,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = zscore(recording_f, dtype="float32") noise_levels = np.ones(num_channels, dtype=np.float32) + ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index dd36135b8d..59983cbe03 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -84,7 +84,7 @@ class CircusClustering: "waveforms": {"ms_before": 2, "ms_after": 2, "max_spikes_per_unit": 100}, "radius_um": 100, "selection_method": "closest_to_centroid", - "n_svd": [6, 6], + "n_svd": [5, 10], "ms_before": 1, "ms_after": 1, "random_seed": 42, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index ea36b75847..6278067987 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -508,8 +508,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "amplitudes": [0.6, 2], - "omp_min_sps": 0.1, + "amplitudes": [0.75, 1.25], + "omp_min_sps": 0.05, "waveform_extractor": None, "random_chunk_kwargs": {}, "noise_levels": None, From b4f760e0f89c5a1db80efff8edec36ec9c47c111 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:17:34 +0000 Subject: [PATCH 21/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b344606c52..28b9652a3a 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -74,7 +74,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = zscore(recording_f, dtype="float32") noise_levels = np.ones(num_channels, dtype=np.float32) - ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() detection_params.update(job_kwargs) From baf0280da90e05a51852d132fbe815db7a769b63 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 15 Nov 2023 15:14:02 +0100 Subject: [PATCH 22/48] Cleaning --- .../clustering/random_projections.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 3053bfbdd0..7d5b58551b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -113,32 +113,12 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000) nsamples = nbefore + nafter - import scipy - - x = np.random.randn(100, nsamples, num_chans).astype(np.float32) - x = scipy.signal.savgol_filter(x, node2.window_length, node2.order, axis=1) - - ptps = np.ptp(x, axis=1) - a, b = np.histogram(ptps.flatten(), np.linspace(0, 100, 1000)) - ydata = np.cumsum(a) / a.sum() - xdata = b[1:] - - from scipy.optimize import curve_fit - - def sigmoid(x, L, x0, k, b): - y = L / (1 + np.exp(-k * (x - x0))) + b - return y - - p0 = [max(ydata), np.median(xdata), 1, min(ydata)] # this is an mandatory initial guess - popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) - node3 = RandomProjectionsFeature( recording, parents=[node0, node2], return_output=True, projections=projections, radius_um=params["radius_um"], - sigmoid=None, sparse=True, ) From 50853f0a2d90dfa116d5ab547272edaaa7099198 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 10:50:24 +0100 Subject: [PATCH 23/48] WIP on the peeler --- .../sorters/internal/spyking_circus2.py | 1 - .../sortingcomponents/clustering/circus.py | 9 +---- .../clustering/clustering_tools.py | 6 +--- .../clustering/random_projections.py | 8 +---- .../sortingcomponents/matching/circus.py | 34 ++++++++----------- 5 files changed, 18 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 37478b1aa4..c690b4228d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -167,7 +167,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_params = params["matching"].copy() matching_params["waveform_extractor"] = we - matching_params.update({"noise_levels": noise_levels}) matching_job_params = job_kwargs.copy() for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 59983cbe03..0905e61169 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -88,7 +88,6 @@ class CircusClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels": None, "shared_memory": True, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -118,12 +117,6 @@ def main_function(cls, recording, peaks, params): nafter = int(ms_after * fs / 1000.0) num_samples = nbefore + nafter num_chans = recording.get_num_channels() - - if d["noise_levels"] is None: - noise_levels = get_noise_levels(recording, return_scaled=False) - else: - noise_levels = d["noise_levels"] - np.random.seed(d["random_seed"]) if params["tmp_folder"] is None: @@ -280,7 +273,7 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) del we, sorting diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 66fe660918..72da52d7a0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -534,7 +534,6 @@ def remove_duplicates( def remove_duplicates_via_matching( waveform_extractor, - noise_levels, peak_labels, method_kwargs={}, job_kwargs={}, @@ -542,7 +541,6 @@ def remove_duplicates_via_matching( method="circus-omp-svd", ): from spikeinterface.sortingcomponents.matching import find_spikes_from_templates - from spikeinterface import get_noise_levels from spikeinterface.core import BinaryRecordingExtractor from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms @@ -598,9 +596,7 @@ def remove_duplicates_via_matching( local_params.update( { "waveform_extractor": waveform_extractor, - "noise_levels": noise_levels, - "amplitudes": [0.975, 1.025], - "omp_min_sps": 0.05, + "amplitudes": [0.975, 1.025] } ) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 7d5b58551b..fee35709d7 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -77,12 +77,6 @@ def main_function(cls, recording, peaks, params): nafter = int(params["ms_after"] * fs / 1000.0) num_samples = nbefore + nafter num_chans = recording.get_num_channels() - - if d["noise_levels"] is None: - noise_levels = get_noise_levels(recording, return_scaled=False) - else: - noise_levels = d["noise_levels"] - np.random.seed(d["random_seed"]) if params["tmp_folder"] is None: @@ -219,7 +213,7 @@ def main_function(cls, recording, peaks, params): cleaning_params["tmp_folder"] = tmp_folder labels, peak_labels = remove_duplicates_via_matching( - we, noise_levels, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + we, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params ) del we, sorting diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 6278067987..8bc4b34806 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -496,9 +496,6 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float Stopping criteria of the OMP algorithm, in percentage of the norm - noise_levels: array - The noise levels, for every channels. If None, they will be automatically - computed random_chunk_kwargs: dict Parameters for computing noise levels, if not provided (sub optimal) sparse_kwargs: dict @@ -509,10 +506,9 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): _default_params = { "amplitudes": [0.75, 1.25], - "omp_min_sps": 0.05, + "omp_min_sps": 1e-4, "waveform_extractor": None, "random_chunk_kwargs": {}, - "noise_levels": None, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], @@ -612,10 +608,6 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["sampling_frequency"] = d["waveform_extractor"].recording.get_sampling_frequency() d["vicinity"] *= d["num_samples"] - if d["noise_levels"] is None: - print("CircusOMPPeeler : noise should be computed outside") - d["noise_levels"] = get_noise_levels(recording, **d["random_chunk_kwargs"], return_scaled=False) - if "templates" not in d: d = cls._prepare_templates(d) else: @@ -638,10 +630,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs): d["unit_overlaps_tables"][i] = np.zeros(d["num_templates"], dtype=int) d["unit_overlaps_tables"][i][d["unit_overlaps_indices"][i]] = np.arange(len(d["unit_overlaps_indices"][i])) - omp_min_sps = d["omp_min_sps"] - # d["stop_criteria"] = omp_min_sps * np.sqrt(d["noise_levels"].sum() * d["num_samples"]) - d["stop_criteria"] = omp_min_sps * np.maximum(d["norms"], np.sqrt(d["noise_levels"].sum() * d["num_samples"])) - + d["stop_criteria"] = d["omp_min_sps"] return d @classmethod @@ -675,7 +664,7 @@ def main_function(cls, traces, d): neighbor_window = num_samples - 1 min_amplitude, max_amplitude = d["amplitudes"] ignored_ids = d["ignored_ids"] - stop_criteria = d["stop_criteria"][:, np.newaxis] + stop_criteria = d["stop_criteria"] vicinity = d["vicinity"] rank = d["rank"] @@ -717,13 +706,15 @@ def main_function(cls, traces, d): neighbors = {} cached_overlaps = {} - is_valid = scalar_products > stop_criteria all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) + new_error = np.linalg.norm(scalar_products) + delta_error = np.inf - while np.any(is_valid): - best_amplitude_ind = scalar_products[is_valid].argmax() - best_cluster_ind, peak_index = np.unravel_index(idx_lookup[is_valid][best_amplitude_ind], idx_lookup.shape) + while delta_error > stop_criteria: + + best_amplitude_ind = scalar_products.argmax() + best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) if num_selection > 0: delta_t = selection[1] - peak_index @@ -818,7 +809,12 @@ def main_function(cls, traces, d): to_add = diff_amp * local_overlaps[:, tdx[0] : tdx[1]] scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add - is_valid = scalar_products > stop_criteria + previous_error = new_error + new_error = np.linalg.norm(scalar_products) + if previous_error != 0: + delta_error = np.abs(new_error / previous_error - 1) + else: + delta_error = 0 is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) From df226b651db4685162115fc5f5c2a2cbb3f8e57f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 09:51:53 +0000 Subject: [PATCH 24/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 7 +------ src/spikeinterface/sortingcomponents/matching/circus.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 72da52d7a0..052a596c63 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,12 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update( - { - "waveform_extractor": waveform_extractor, - "amplitudes": [0.975, 1.025] - } - ) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025]}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8bc4b34806..21de446162 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -712,7 +712,6 @@ def main_function(cls, traces, d): delta_error = np.inf while delta_error > stop_criteria: - best_amplitude_ind = scalar_products.argmax() best_cluster_ind, peak_index = np.unravel_index(best_amplitude_ind, scalar_products.shape) From 919a40494ab3635efbc3f8df31ffd3d7308f1b98 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 11:54:04 +0100 Subject: [PATCH 25/48] Fix for cleaning via matching --- .../sortingcomponents/matching/circus.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8bc4b34806..091b5d32a4 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -676,13 +676,13 @@ def main_function(cls, traces, d): # Filter using overlap-and-add convolution if len(ignored_ids) > 0: - mask = ~np.isin(np.arange(num_templates), ignored_ids) - spatially_filtered_data = np.matmul(d["spatial"][:, mask, :], traces.T[np.newaxis, :, :]) - scaled_filtered_data = spatially_filtered_data * d["singular"][:, mask, :] + not_ignored = ~np.isin(np.arange(num_templates), ignored_ids) + spatially_filtered_data = np.matmul(d["spatial"][:, not_ignored, :], traces.T[np.newaxis, :, :]) + scaled_filtered_data = spatially_filtered_data * d["singular"][:, not_ignored, :] objective_by_rank = scipy.signal.oaconvolve( - scaled_filtered_data, d["temporal"][:, mask, :], axes=2, mode="valid" + scaled_filtered_data, d["temporal"][:, not_ignored, :], axes=2, mode="valid" ) - scalar_products[mask] += np.sum(objective_by_rank, axis=0) + scalar_products[not_ignored] += np.sum(objective_by_rank, axis=0) scalar_products[ignored_ids] = -np.inf else: spatially_filtered_data = np.matmul(d["spatial"], traces.T[np.newaxis, :, :]) @@ -693,7 +693,6 @@ def main_function(cls, traces, d): num_spikes = 0 spikes = np.empty(scalar_products.size, dtype=spike_dtype) - idx_lookup = np.arange(scalar_products.size).reshape(num_templates, -1) M = np.zeros((num_templates, num_templates), dtype=np.float32) @@ -708,7 +707,10 @@ def main_function(cls, traces, d): all_amplitudes = np.zeros(0, dtype=np.float32) is_in_vicinity = np.zeros(0, dtype=np.int32) - new_error = np.linalg.norm(scalar_products) + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) + else: + new_error = np.linalg.norm(scalar_products) delta_error = np.inf while delta_error > stop_criteria: @@ -810,11 +812,11 @@ def main_function(cls, traces, d): scalar_products[overlapping_templates, idx[0] : idx[1]] -= to_add previous_error = new_error - new_error = np.linalg.norm(scalar_products) - if previous_error != 0: - delta_error = np.abs(new_error / previous_error - 1) + if len(ignored_ids) > 0: + new_error = np.linalg.norm(scalar_products[not_ignored]) else: - delta_error = 0 + new_error = np.linalg.norm(scalar_products) + delta_error = np.abs(new_error / previous_error - 1) is_valid = (final_amplitudes > min_amplitude) * (final_amplitudes < max_amplitude) valid_indices = np.where(is_valid) From ee239e7a0ba940e74a6086c1cc0376d9f740ed0b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Thu, 16 Nov 2023 14:34:02 +0100 Subject: [PATCH 26/48] Closing the gap --- src/spikeinterface/sorters/internal/spyking_circus2.py | 1 + src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- .../sortingcomponents/clustering/random_projections.py | 3 +-- src/spikeinterface/sortingcomponents/matching/circus.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c690b4228d..1d4f04a382 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -71,6 +71,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording_f = common_reference(recording_f) else: recording_f = recording + recording_f.annotate(is_filtered=True) # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 0905e61169..6d29fe3b37 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -12,7 +12,7 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances +from spikeinterface.core import get_global_tmp_folder, get_channel_distances from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index fee35709d7..dcb84cb6ff 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -12,7 +12,7 @@ HAVE_HDBSCAN = False import random, string, os -from spikeinterface.core import get_global_tmp_folder, get_noise_levels, get_channel_distances, get_random_data_chunks +from spikeinterface.core import get_global_tmp_folder, get_channel_distances, get_random_data_chunks from sklearn.preprocessing import QuantileTransformer, MaxAbsScaler from spikeinterface.core.waveform_tools import extract_waveforms_to_buffers from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip @@ -48,7 +48,6 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 39decc2380..839fe1dbd2 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -505,8 +505,8 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): """ _default_params = { - "amplitudes": [0.75, 1.25], - "omp_min_sps": 1e-4, + "amplitudes": [0.6, 1.4], + "omp_min_sps": 1e-5, "waveform_extractor": None, "random_chunk_kwargs": {}, "rank": 5, From 9054f7b79ea7cd742bd383f9921ecafe58fd02f5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 14:46:01 +0100 Subject: [PATCH 27/48] Speeding up merging --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 052a596c63..1167541ebf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025]}) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps" : 1e-3}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) From b3852eabe99770b5ca12f8b3ba2276c544041ede Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 13:46:23 +0000 Subject: [PATCH 28/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 1167541ebf..629b0b13ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -593,7 +593,7 @@ def remove_duplicates_via_matching( local_params = method_kwargs.copy() - local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps" : 1e-3}) + local_params.update({"waveform_extractor": waveform_extractor, "amplitudes": [0.975, 1.025], "omp_min_sps": 1e-3}) spikes_per_units, counts = np.unique(waveform_extractor.sorting.to_spike_vector()["unit_index"], return_counts=True) indices = np.argsort(counts) From ddb7eb964ca7d975d06ecb8f2e9a02b226ab924c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 16 Nov 2023 15:31:23 +0100 Subject: [PATCH 29/48] WIP --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 839fe1dbd2..77bbf3a73b 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -506,7 +506,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): _default_params = { "amplitudes": [0.6, 1.4], - "omp_min_sps": 1e-5, + "omp_min_sps": 5e-5, "waveform_extractor": None, "random_chunk_kwargs": {}, "rank": 5, From 3b514fe3d32d511556b65c50b965e90dbacf8f3a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 09:59:22 +0100 Subject: [PATCH 30/48] Documentation --- .../sortingcomponents/matching/circus.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 77bbf3a73b..b0311e10bd 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -495,12 +495,15 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): amplitude: tuple (Minimal, Maximal) amplitudes allowed for every template omp_min_sps: float - Stopping criteria of the OMP algorithm, in percentage of the norm - random_chunk_kwargs: dict - Parameters for computing noise levels, if not provided (sub optimal) + Stopping criteria of the OMP algorithm, as relative error sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. + rank: int + Number of components used internally by the SVD (default 5) + vicinity: int + Size of the area surrounding a spike to perform modification (expressed in terms + of template temporal width) ----- """ @@ -508,7 +511,6 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): "amplitudes": [0.6, 1.4], "omp_min_sps": 5e-5, "waveform_extractor": None, - "random_chunk_kwargs": {}, "rank": 5, "sparse_kwargs": {"method": "ptp", "threshold": 1}, "ignored_ids": [], From 69f6f8ccebe2085703c918a05e4693778a391067 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 17 Nov 2023 05:35:57 -0500 Subject: [PATCH 31/48] Sam & Heberto feedback --- doc/import.rst | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/doc/import.rst b/doc/import.rst index 430e53ad4e..fbf5c6f985 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -48,9 +48,7 @@ more aliases to keep track of. Flat Import ----------- -A second option is to import the SpikeInterface package in :code:`full` mode. This would be similar to -what is seen with packages like NumPy (:code:`np`) or Pandas (:code:`pd`), which offer the majority of -their functionality with a single alias and the option to import additional functionality separately. +A second option is to import the SpikeInterface package in :code:`full` mode. To accomplish this one does: @@ -60,8 +58,6 @@ To accomplish this one does: This import statement will import all of the SpikeInterface modules as one flattened module. -Note that importing :code:`spikeinterface.full` will take a few extra seconds, because some modules use -just-in-time :code:`numba` compilation performed at the time of import. We recommend this approach for advanced (or lazy) users, since it requires a deeper knowledge of the API. The advantage being that users can access all functions using one alias without the need of memorizing all aliases. From 7595d9ce3ba57f11ab33e8260f26654c1fca204a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:36:53 +0000 Subject: [PATCH 32/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/import.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/import.rst b/doc/import.rst index fbf5c6f985..ad5e1da0ee 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -48,7 +48,7 @@ more aliases to keep track of. Flat Import ----------- -A second option is to import the SpikeInterface package in :code:`full` mode. +A second option is to import the SpikeInterface package in :code:`full` mode. To accomplish this one does: From 836d77b256b17aec7dae1009f341dc4e2b279f44 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 17 Nov 2023 06:53:00 -0500 Subject: [PATCH 33/48] import spikeinterface.core as si --- doc/import.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/import.rst b/doc/import.rst index ad5e1da0ee..be3f7d5afb 100644 --- a/doc/import.rst +++ b/doc/import.rst @@ -17,7 +17,7 @@ be accomplished by: .. code-block:: python - import spikeinterface as si + import spikeinterface.core as si to import the :code:`core` module followed by: From 47643f7555a9c950d63c62085ae1c5f830e7898d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:09 +0100 Subject: [PATCH 34/48] Update src/spikeinterface/sortingcomponents/clustering/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 6d29fe3b37..401ed58871 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -95,7 +95,7 @@ class CircusClustering: @classmethod def main_function(cls, recording, peaks, params): - assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" + assert HAVE_HDBSCAN, "random projections clustering needs hdbscan to be installed" if "n_jobs" in params["job_kwargs"]: if params["job_kwargs"]["n_jobs"] == -1: From b78433068d9a464485e4397fb7f0272d0d4b9093 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:19 +0100 Subject: [PATCH 35/48] Update src/spikeinterface/sortingcomponents/matching/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index b0311e10bd..d23095d838 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -500,7 +500,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. rank: int - Number of components used internally by the SVD (default 5) + Number of components used internally by the SVD vicinity: int Size of the area surrounding a spike to perform modification (expressed in terms of template temporal width) From 0c65b2ca8895895b1a89ddc9a8eb9688b4b42c9a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:25 +0100 Subject: [PATCH 36/48] Update src/spikeinterface/sortingcomponents/matching/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/matching/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index d23095d838..cfdca6f612 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -499,7 +499,7 @@ class CircusOMPSVDPeeler(BaseTemplateMatchingEngine): sparse_kwargs: dict Parameters to extract a sparsity mask from the waveform_extractor, if not already sparse. - rank: int + rank: int, default: 5 Number of components used internally by the SVD vicinity: int Size of the area surrounding a spike to perform modification (expressed in terms From 3def285b8af6600337abde9374ec67dcea685048 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 17 Nov 2023 13:00:34 +0100 Subject: [PATCH 37/48] Update src/spikeinterface/sortingcomponents/clustering/circus.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 401ed58871..47c5a1e58f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -34,7 +34,7 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): """ - Helper function to extractor waveforms at max channel from a peak list + Helper function to extract waveforms at the max channel from a peak list """ From 850c6dc4bd50818d5b45b857badcce2a695ccb40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 17 Nov 2023 15:31:24 +0100 Subject: [PATCH 38/48] Fixed a bug when caching recording noise levels Caching needs to depend on the method, otherwise the result might be erroneous. --- src/spikeinterface/core/recording_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 34313cd7ae..030cce7faa 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -164,9 +164,9 @@ def get_noise_levels( """ if return_scaled: - key = "noise_level_scaled" + key = f"noise_level_{method}_scaled" else: - key = "noise_level_raw" + key = f"noise_level_{method}_raw" if key in recording.get_property_keys() and not force_recompute: noise_levels = recording.get_property(key=key) From 675667f6e7e60af5c37cc862056fd64a6b093032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Fri, 17 Nov 2023 15:42:12 +0100 Subject: [PATCH 39/48] Fixed noise levels propagation --- src/spikeinterface/core/baserecording.py | 2 +- .../tests/test_noise_levels_propagation.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 6dfe038558..8bd31abfce 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -29,7 +29,7 @@ class BaseRecording(BaseRecordingSnippets): _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] # recording do not handle features - _skip_properties = ["noise_level_raw", "noise_level_scaled"] + _skip_properties = ["noise_level_std_raw", "noise_level_std_scaled", "noise_level_mad_raw", "noise_level_mad_scaled"] def __init__(self, sampling_frequency: float, channel_ids: List, dtype): BaseRecordingSnippets.__init__( diff --git a/src/spikeinterface/core/tests/test_noise_levels_propagation.py b/src/spikeinterface/core/tests/test_noise_levels_propagation.py index d6dbd08abe..6f1b46bd33 100644 --- a/src/spikeinterface/core/tests/test_noise_levels_propagation.py +++ b/src/spikeinterface/core/tests/test_noise_levels_propagation.py @@ -19,26 +19,26 @@ def test_skip_noise_levels_propagation(): rec = generate_recording(durations=[5.0], num_channels=4) rec.set_property("test", ["1", "2", "3", "4"]) rec = rec.save() - noise_level_raw = get_noise_levels(rec, return_scaled=False) - assert "noise_level_raw" in rec.get_property_keys() + noise_level_raw = get_noise_levels(rec, return_scaled=False, method="mad") + assert "noise_level_mad_raw" in rec.get_property_keys() rec_frame_slice = rec.frame_slice(start_frame=0, end_frame=1000) - assert "noise_level_raw" not in rec_frame_slice.get_property_keys() + assert "noise_level_mad_raw" not in rec_frame_slice.get_property_keys() assert "test" in rec_frame_slice.get_property_keys() # make scaled rec.set_channel_gains([100] * 4) rec.set_channel_offsets([0] * 4) - noise_level_scaled = get_noise_levels(rec, return_scaled=True) - assert "noise_level_scaled" in rec.get_property_keys() + noise_level_scaled = get_noise_levels(rec, return_scaled=True, method="std") + assert "noise_level_std_scaled" in rec.get_property_keys() rec_frame_slice = rec.frame_slice(start_frame=0, end_frame=1000) rec_concat = concatenate_recordings([rec] * 5) - assert "noise_level_raw" not in rec_concat.get_property_keys() - assert "noise_level_scaled" not in rec_concat.get_property_keys() - assert "noise_level_raw" not in rec_frame_slice.get_property_keys() - assert "noise_level_scaled" not in rec_frame_slice.get_property_keys() + assert "noise_level_mad_raw" not in rec_concat.get_property_keys() + assert "noise_level_std_scaled" not in rec_concat.get_property_keys() + assert "noise_level_mad_raw" not in rec_frame_slice.get_property_keys() + assert "noise_level_std_scaled" not in rec_frame_slice.get_property_keys() assert "test" in rec_frame_slice.get_property_keys() assert "test" in rec_concat.get_property_keys() From e5001ea6a97b2840e2c016b139b4bde494e91fbc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:43:33 +0000 Subject: [PATCH 40/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/baserecording.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8bd31abfce..0aff78499b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -29,7 +29,12 @@ class BaseRecording(BaseRecordingSnippets): _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] _main_features = [] # recording do not handle features - _skip_properties = ["noise_level_std_raw", "noise_level_std_scaled", "noise_level_mad_raw", "noise_level_mad_scaled"] + _skip_properties = [ + "noise_level_std_raw", + "noise_level_std_scaled", + "noise_level_mad_raw", + "noise_level_mad_scaled", + ] def __init__(self, sampling_frequency: float, channel_ids: List, dtype): BaseRecordingSnippets.__init__( From 200426353480bd41a9375076aca66f2e5485c617 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 17 Nov 2023 19:10:50 -0500 Subject: [PATCH 41/48] various updates to docs --- doc/modules/core.rst | 17 ++++++++-------- doc/modules/curation.rst | 2 +- doc/modules/exporters.rst | 4 ++-- doc/modules/widgets.rst | 41 +++++++++++++++++++-------------------- 4 files changed, 32 insertions(+), 32 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index ef7b266a86..697681bcb1 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -143,7 +143,7 @@ with 10 units: .. code-block:: python - unit_ids = sorting.channel_ids + unit_ids = sorting.unit_ids num_channels = sorting.get_num_units() sampling_frequency = sorting.sampling_frequency @@ -419,11 +419,16 @@ probes, such as Neuropixels, because the waveforms of a unit will only appear on Sparsity is defined as the subset of channels on which waveforms (and related information) are defined. Of course, sparsity is not global, but it is unit-specific. +**NOTE** As of version :code:`0.99.0` the default for a :code:`extract_waveforms()` has `sparse=True`, ie every :code:`waveform_extractor` +will be sparse by default. Thus for users that wish to have dense waveforms they must set `sparse=False`. Keyword arguments +can still be input into the :code:`extract_wavforms()` to generate the desired sparsity as explained below. + Sparsity can be computed from a :py:class:`~spikeinterface.core.WaveformExtractor` object with the :py:func:`~spikeinterface.core.compute_sparsity` function: .. code-block:: python + # in this case 'we' should be a dense waveform_extractor sparsity = compute_sparsity(we, method="radius", radius_um=40) The returned :code:`sparsity` is a :py:class:`~spikeinterface.core.ChannelSparsity` object, which has convenient @@ -437,11 +442,11 @@ methods to access the sparsity information in several ways: There are several methods to compute sparsity, including: * | :code:`method="radius"`: selects the channels based on the channel locations. For example, using a - | :code:`radius_um=40`, will select, for each unit, the channels which are whithin 40um of the channel with the + | :code:`radius_um=40`, will select, for each unit, the channels which are within 40um of the channel with the | largest amplitude (*the extremum channel*). **This is the recommended method for high-density probes** * | :code:`method="best_channels"`: selects the best :code:`num_channels` channels based on their amplitudes. Note that | in this case the selected channels might not be close to each other. -* | :code:`method="threshold"`: selects channels based on an SNR threshold (:code:`threshold` argument) +* | :code:`method="threshold"`: selects channels based on an SNR threshold (given by the :code:`threshold` argument) * | :code:`method="by_property"`: selects channels based on a property, such as :code:`group`. This method is recommended | when working with tetrodes. @@ -460,10 +465,6 @@ waveforms folder. .. _save_load: -**NOTE:** As of SpikeInterface 0.99.0, :code:`extract_waveforms` now defaults to :code:`sparse=True`, so that default -behavior is to always have sparse waveforms. To have dense waveforms (the previous default behavior), remember to set -:code:`sparsity=False`. - Saving, loading, and compression -------------------------------- @@ -635,7 +636,7 @@ the new objects will be a *view* of the original ones. # here we load a very long recording and sorting recording = read_spikeglx('np_folder') - sorting =read_kilosrt('ks_folder') + sorting =read_kilosort('ks_folder') # keep one channel of every tenth channel keep_ids = recording.channel_ids[::10] diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index d533cdcac8..032988818b 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -76,7 +76,7 @@ merges. Therefore, it has many parameters and options. clean_sorting = MergeUnitsSorting(parent_sorting=sorting, units_to_merge=merges) -Manual curation with sorting view +Manual curation with sortingview --------------------------------- Within the :code:`sortingview` widgets backend (see :ref:`sorting_view`), the diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index d9c4be963f..b322139c2b 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -28,7 +28,7 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` from spikeinterface.exporters import export_to_phy # the waveforms are sparse so it is faster to export to phy - we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms', sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='waveforms') # some computations are done before to control all options _ = compute_spike_amplitudes(waveform_extractor=we) @@ -71,7 +71,7 @@ with many units! # the waveforms are sparse for more interpretable figures - we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf', sparse=True) + we = extract_waveforms(recording=recording, sorting=sorting, folder='path/to/wf',) # some computations are done before to control all options _ = compute_spike_amplitudes(waveform_extractor=we) diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index d5a2ee87c6..68106d13a9 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -14,7 +14,7 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi * | :code:`sortingview`: web-based and interactive rendering using the `sortingview `_ | and `FIGURL `_ packages. -Version 0.100.0, also come with this new backend: +Version 0.100.0, also comes with this new backend: * | :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package @@ -197,7 +197,7 @@ The functions have the following additional arguments: # sortingview backend w_ts = sw.plot_traces(recording=recording, backend="sortingview") - w_ss = sw.plot_sorting_summary(waveform_extractor = we, curation=True, backend="sortingview") + w_ss = sw.plot_sorting_summary(waveform_extractor=we, curation=True, backend="sortingview") **Output:** @@ -259,11 +259,22 @@ The :code:`ephyviewer` backend is currently only available for the :py:func:`~sp Available plotting functions ---------------------------- +* :py:func:`~spikeinterface.widgets.plot_agreement_matrix` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_all_amplitudes_distributions` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_amplitudes` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_autocorrelograms` (backends: :code:`matplotlib`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_confusion_matrix` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_comparison_collision_by_similarity` * :py:func:`~spikeinterface.widgets.plot_crosscorrelograms` (backends: :code:`matplotlib`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_isi_distribution` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_motion` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_multicomparison_agreement` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_multicomparison_agreement_by_sorter` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_multicomparison_graph` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_peak_activity` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_probe_map` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_quality_metrics` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_rasters` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_sorting_summary` (backends: :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_spike_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`) * :py:func:`~spikeinterface.widgets.plot_spikes_on_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`) @@ -272,26 +283,14 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_traces` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`, :code:`ephyviewer`) * :py:func:`~spikeinterface.widgets.plot_unit_depths` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_locations` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) +* :py:func:`~spikeinterface.widgets.plot_unit_presence` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_unit_probe_map` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_summary` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_templates` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_unit_waveforms_density_map` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_unit_waveforms` (backends: :code:`matplotlib`, :code:`ipywidgets`) - - -# Which have been moved over? - -* :py:func:`~spikeinterface.widgets.plot_rasters` -* :py:func:`~spikeinterface.widgets.plot_probe_map` -* :py:func:`~spikeinterface.widgets.plot_isi_distribution` -* :py:func:`~spikeinterface.widgets.plot_drift_over_time` -* :py:func:`~spikeinterface.widgets.plot_peak_activity_map` -* :py:func:`~spikeinterface.widgets.plot_principal_component` -* :py:func:`~spikeinterface.widgets.plot_unit_probe_map` -* :py:func:`~spikeinterface.widgets.plot_confusion_matrix` -* :py:func:`~spikeinterface.widgets.plot_agreement_matrix` -* :py:func:`~spikeinterface.widgets.plot_multicomp_graph` -* :py:func:`~spikeinterface.widgets.plot_multicomp_agreement` -* :py:func:`~spikeinterface.widgets.plot_multicomp_agreement_by_sorter` -* :py:func:`~spikeinterface.widgets.plot_comparison_collision_pair_by_pair` -* :py:func:`~spikeinterface.widgets.plot_comparison_collision_by_similarity` -* :py:func:`~spikeinterface.widgets.plot_sorting_performance` +* :py:func:`~spikeinterface.widgets.plot_study_run_times` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_study_unit_counts` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_study_agreement_matrix` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_study_summary` (backends: :code:`matplotlib`) +* :py:func:`~spikeinterface.widgets.plot_study_comparison_collision_by_similarity` (backends: :code:`matplotlib`) From d86aba38b65d8461d688d240deaaeceec8097556 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sat, 18 Nov 2023 12:00:32 -0500 Subject: [PATCH 42/48] update examples for modules gallery --- .../core/plot_1_recording_extractor.py | 40 ++++++++++--------- .../core/plot_3_handle_probe_info.py | 14 +++---- .../core/plot_6_handle_times.py | 6 +-- .../extractors/plot_1_read_various_formats.py | 16 ++++---- .../plot_2_working_with_unscaled_traces.py | 8 ++-- .../qualitymetrics/plot_3_quality_mertics.py | 19 +++++---- .../qualitymetrics/plot_4_curation.py | 21 ++++++---- 7 files changed, 69 insertions(+), 55 deletions(-) diff --git a/examples/modules_gallery/core/plot_1_recording_extractor.py b/examples/modules_gallery/core/plot_1_recording_extractor.py index f20bf6497d..f5d3ee1db2 100644 --- a/examples/modules_gallery/core/plot_1_recording_extractor.py +++ b/examples/modules_gallery/core/plot_1_recording_extractor.py @@ -26,7 +26,7 @@ num_channels = 7 sampling_frequency = 30000. # in Hz -durations = [10., 15.] #  in s for 2 segments +durations = [10., 15.] # in s for 2 segments num_segments = 2 num_timepoints = [int(sampling_frequency * d) for d in durations] @@ -38,7 +38,7 @@ traces1 = np.random.normal(0, 10, (num_timepoints[1], num_channels)) ############################################################################## -# And instantiate a :py:class:`~spikeinterface.core.NumpyRecording`. Each object has a pretty print to +# And instantiate a :py:class:`~spikeinterface.core.NumpyRecording`. Each object has a pretty print to # summarize its content: recording = se.NumpyRecording(traces_list=[traces0, traces1], sampling_frequency=sampling_frequency) @@ -47,24 +47,28 @@ ############################################################################## # We can now print properties that the :code:`RecordingExtractor` retrieves from the underlying recording. -print('Num. channels = {}'.format(len(recording.get_channel_ids()))) -print('Sampling frequency = {} Hz'.format(recording.get_sampling_frequency())) -print('Num. timepoints seg0= {}'.format(recording.get_num_segments())) -print('Num. timepoints seg0= {}'.format(recording.get_num_frames(segment_index=0))) -print('Num. timepoints seg1= {}'.format(recording.get_num_frames(segment_index=1))) +print(f'Number of channels = {recording.get_channel_ids()}') +print(f'Sampling frequency = {recording.get_sampling_frequency()} Hz') +print(f'Number of segments= {recording.get_num_segments()}') +print(f'Number of timepoints in seg0= {recording.get_num_frames(segment_index=0)}') +print(f'Number of timepoints in seg1= {recording.get_num_frames(segment_index=1)}') ############################################################################## -# The geometry of the Probe is handle with the :probeinterface:`ProbeInterface <>`. -# Let's generate a linear probe: +# The geometry of the Probe is handled with the :probeinterface:`ProbeInterface <>` library. +# Let's generate a linear probe by specifying our number of electrodes/contacts (num_elec) +# the distance between the contacts (ypitch), their shape (contact_shapes) and their size +# (contact_shape_params): from probeinterface import generate_linear_probe from probeinterface.plotting import plot_probe probe = generate_linear_probe(num_elec=7, ypitch=20, contact_shapes='circle', contact_shape_params={'radius': 6}) -# the probe has to be wired to the recording +# the probe has to be wired to the recording device (i.e., which electrode corresponds to an entry in the data +# matrix) probe.set_device_channel_indices(np.arange(7)) +# then we need to actually set the probe to the recording object recording = recording.set_probe(probe) plot_probe(probe) @@ -76,14 +80,14 @@ ############################################################################## # We can read the written recording back with the proper extractor. -# Note that this new recording is now "on disk" and not "in memory" as the Numpy recording. -# This means that the loading is "lazy" and the data are not loaded in memory. +# Note that this new recording is now "on disk" and not "in memory" as the Numpy recording was. +# This means that the loading is "lazy" and the data are not loaded into memory. recording2 = se.BinaryRecordingExtractor(file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype) print(recording2) ############################################################################## -#  Loading traces in memory is done on demand: +# Loading traces in memory is done on demand: # entire segment 0 traces0 = recording2.get_traces(segment_index=0) @@ -93,8 +97,8 @@ print(traces1_short.shape) ############################################################################## -# A recording internally has :code:`channel_ids`: these are a vector that can have -# dtype int or str: +# Internally, a recording has :code:`channel_ids`: that are a vector that can have a +# dtype of :code:`int` or :code:`str`: print('chan_ids (dtype=int):', recording.get_channel_ids()) @@ -111,7 +115,7 @@ print(traces.shape) ############################################################################## -# You can also get a a recording with a subset of channel (a channel slice): +# You can also get a recording with a subset of channels (i.e. a channel slice): recording4 = recording3.channel_slice(channel_ids=['a', 'c', 'e']) print(recording4) @@ -136,7 +140,7 @@ ############################################################################### # A recording can be "dumped" (exported) to: # * a dict -#  * a json file +# * a json file # * a pickle file # # The "dump" operation is lazy, i.e., the traces are not exported. @@ -164,7 +168,7 @@ # # If you wish to also store the traces in a compact way you need to use the # :code:`save()` function. This operation is very useful to save traces obtained -# after long computation (e.g. filtering): +# after long computations (e.g. filtering or referencing): recording2.save(folder='./my_recording') diff --git a/examples/modules_gallery/core/plot_3_handle_probe_info.py b/examples/modules_gallery/core/plot_3_handle_probe_info.py index 1900b59433..d134b29ec5 100644 --- a/examples/modules_gallery/core/plot_3_handle_probe_info.py +++ b/examples/modules_gallery/core/plot_3_handle_probe_info.py @@ -4,9 +4,9 @@ In order to properly spike sort, you may need to load information related to the probe you are using. -SpikeInterface internally uses :probeinterface:`ProbeInterface <>` to handle probe or probe groups for recordings. +SpikeInterface internally uses :probeinterface:`ProbeInterface <>` to handle probes or probe groups for recordings. -Depending on the dataset, the :py:class:`~probeinterface.Probe` object can be already included or needs to be set +Depending on the dataset, the :py:class:`~probeinterface.Probe` object may already be included or might need to be set manually. Here's how! @@ -22,7 +22,7 @@ ############################################################################### # This generator already contain a probe object that you can retrieve -# directly an plot: +# directly and plot: probe = recording.get_probe() print(probe) @@ -32,13 +32,13 @@ plot_probe(probe) ############################################################################### -# You can also overwrite the probe. In that case you need to manually make +# You can also overwrite the probe. In this case you need to manually make # the wiring (e.g. virtually connect each electrode to the recording device). # Let's use a probe from Cambridge Neurotech with 32 channels: from probeinterface import get_probe -other_probe = get_probe('cambridgeneurotech', 'ASSY-37-E-1') +other_probe = get_probe(manufacturer='cambridgeneurotech', probe_name='ASSY-37-E-1') print(other_probe) other_probe.set_device_channel_indices(np.arange(32)) @@ -47,8 +47,8 @@ ############################################################################### # Now let's check what we have loaded. The `group_mode='by_shank'` automatically -# set the 'group' property depending on the shank id. -# We can use this information to split the recording in two sub recordings: +# sets the 'group' property depending on the shank id. +# We can use this information to split the recording into two sub-recordings: print(recording_2_shanks) print(recording_2_shanks.get_property('group')) diff --git a/examples/modules_gallery/core/plot_6_handle_times.py b/examples/modules_gallery/core/plot_6_handle_times.py index 81c67fc31d..4ca116e3c6 100644 --- a/examples/modules_gallery/core/plot_6_handle_times.py +++ b/examples/modules_gallery/core/plot_6_handle_times.py @@ -10,16 +10,16 @@ from spikeinterface.extractors import toy_example ############################################################################## -# First let's generate toy example with a single segment: +# First let's generate a toy example with a single segment: rec, sort = toy_example(num_segments=1) ############################################################################## -# Generally, the time information would be automaticall loaded when reading a +# Generally, the time information would be automatically loaded when reading a # recording. # However, sometimes we might need to add a time vector externally. -# For example, now let's create a time vector by getting the default times and +# For example, let's create a time vector by getting the default times and # adding 5 s: default_times = rec.get_times() diff --git a/examples/modules_gallery/extractors/plot_1_read_various_formats.py b/examples/modules_gallery/extractors/plot_1_read_various_formats.py index ed0ba34396..df85946530 100644 --- a/examples/modules_gallery/extractors/plot_1_read_various_formats.py +++ b/examples/modules_gallery/extractors/plot_1_read_various_formats.py @@ -2,10 +2,10 @@ Read various format into SpikeInterface ======================================= -SpikeInterface can read various format of "recording" (traces) and "sorting" (spike train) data. +SpikeInterface can read various formats of "recording" (traces) and "sorting" (spike train) data. Internally, to read different formats, SpikeInterface either uses: - * a wrapper to the `neo `_ rawio classes + * a wrapper to `neo `_ rawio classes * or a direct implementation Note that: @@ -18,14 +18,14 @@ import matplotlib.pyplot as plt -import spikeinterface as si +import spikeinterface.core as si import spikeinterface.extractors as se ############################################################################## # Let's download some datasets in different formats from the # `ephy_testing_data `_ repo: # -# * MEArec: an simulator format which is hdf5-based. It contains both a "recording" and a "sorting" in the same file. +# * MEArec: a simulator format which is hdf5-based. It contains both a "recording" and a "sorting" in the same file. # * Spike2: file from spike2 devices. It contains "recording" information only. @@ -36,14 +36,14 @@ print(mearec_folder_path) ############################################################################## -# Now that we have downloaded the files let's load them into SI. +# Now that we have downloaded the files, let's load them into SI. # # The :py:func:`~spikeinterface.extractors.read_spike2` function returns one object, # a :py:class:`~spikeinterface.core.BaseRecording`. # # Note that internally this file contains 2 data streams ('0' and '1'), so we need to specify which one we # want to retrieve ('0' in our case). -# the stream information can be retrieve using :py:func:`~spikeinterface.extractors.get_neo_streams` function +# the stream information can be retrieved by using the :py:func:`~spikeinterface.extractors.get_neo_streams` function. stream_names, stream_ids = se.get_neo_streams('spike2', spike2_file_path) print(stream_names) @@ -76,13 +76,13 @@ print(type(sorting)) ############################################################################## -#  The :py:func:`~spikeinterface.extractors.read_mearec` function is equivalent to: +# The :py:func:`~spikeinterface.extractors.read_mearec` function is equivalent to: recording = se.MEArecRecordingExtractor(mearec_folder_path) sorting = se.MEArecSortingExtractor(mearec_folder_path) ############################################################################## -# SI objects (:py:class:`~spikeinterface.core.BaseRecording` and :py:class:`~spikeinterface.core.BaseSorting`) object +# SI objects (:py:class:`~spikeinterface.core.BaseRecording` and :py:class:`~spikeinterface.core.BaseSorting`) # can be plotted quickly with the :py:mod:`spikeinterface.widgets` submodule: import spikeinterface.widgets as sw diff --git a/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py b/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py index 5dd8a39582..69a7e889e4 100644 --- a/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py +++ b/examples/modules_gallery/extractors/plot_2_working_with_unscaled_traces.py @@ -3,7 +3,7 @@ ============================ Some file formats store data in convenient types that require offsetting and scaling in order to convert the -traces to uV. This example shows how to work with unscaled and scaled traces int :py:mod:`spikeinterface.extractors` +traces to uV. This example shows how to work with unscaled and scaled traces in the :py:mod:`spikeinterface.extractors` module. ''' @@ -39,21 +39,21 @@ offset = -2 ** (10 - 1) * gain ############################################################################### -# We are now ready to set gains and offsets to our extractor. We also have to set the :code:`has_unscaled` field to +# We are now ready to set gains and offsets for our extractor. We also have to set the :code:`has_unscaled` field to # :code:`True`: recording.set_channel_gains(gain) recording.set_channel_offsets(offset) ############################################################################### -#  Internally this gains and offsets are handle with properties +# Internally the gain and offset are handled with properties # So the gain could be "by channel". print(recording.get_property('gain_to_uV')) print(recording.get_property('offset_to_uV')) ############################################################################### -# With gains and offset information, we can retrieve traces both in their unscaled (raw) type, and in their scaled +# With gain and offset information, we can retrieve traces both in their unscaled (raw) type, and in their scaled # type: traces_unscaled = recording.get_traces(return_scaled=False) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 7b6aae3e30..7b2fa565b5 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -2,12 +2,12 @@ Quality Metrics Tutorial ======================== -After spike sorting, you might want to validate the goodness of the sorted units. This can be done using the +After spike sorting, you might want to validate the 'goodness' of the sorted units. This can be done using the :code:`qualitymetrics` submodule, which computes several quality metrics of the sorted units. """ -import spikeinterface as si +import spikeinterface.core as si import spikeinterface.extractors as se from spikeinterface.postprocessing import compute_principal_components from spikeinterface.qualitymetrics import (compute_snrs, compute_firing_rates, @@ -29,10 +29,15 @@ # For convenience, metrics are computed on the :code:`WaveformExtractor` object, # because it contains a reference to the "Recording" and the "Sorting" objects: -folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, sparse=False, - ms_before=1, ms_after=2., max_spikes_per_unit=500, - n_jobs=1, chunk_durations='1s') +we = si.extract_waveforms(recording=recording, + sorting=sorting, + folder='waveforms_mearec', + sparse=False, + ms_before=1, + ms_after=2., + max_spikes_per_unit=500, + n_jobs=1, + chunk_durations='1s') print(we) ############################################################################## @@ -51,7 +56,7 @@ # Some metrics are based on the principal component scores, so they require a # :code:`WaveformsPrincipalComponent` object as input: -pc = compute_principal_components(we, load_if_exists=True, +pc = compute_principal_components(waveform_extractor=we, load_if_exists=True, n_components=3, mode='by_channel_local') print(pc) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index edd7a85ce5..2568452de3 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -3,13 +3,13 @@ ================== After spike sorting and computing quality metrics, you can automatically curate the spike sorting output using the -quality metrics. +quality metrics that you have calculated. """ ############################################################################# # Import the modules and/or functions necessary from spikeinterface -import spikeinterface as si +import spikeinterface.core as si import spikeinterface.extractors as se from spikeinterface.postprocessing import compute_principal_components @@ -29,11 +29,16 @@ ############################################################################## # First, we extract waveforms (to be saved in the folder 'wfs_mearec') and -# compute their PC scores: - -we = si.extract_waveforms(recording, sorting, folder='wfs_mearec', - ms_before=1, ms_after=2., max_spikes_per_unit=500, - n_jobs=1, chunk_size=30000) +# compute their PC (principal component) scores: + +we = si.extract_waveforms(recording=recording, + sorting=sorting, + folder='wfs_mearec', + ms_before=1, + ms_after=2., + max_spikes_per_unit=500, + n_jobs=1, + chunk_size=30000) print(we) pc = compute_principal_components(we, load_if_exists=True, n_components=3, mode='by_channel_local') @@ -42,7 +47,7 @@ ############################################################################## # Then we compute some quality metrics: -metrics = compute_quality_metrics(we, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) +metrics = compute_quality_metrics(waveform_extractor=we, metric_names=['snr', 'isi_violation', 'nearest_neighbor']) print(metrics) ############################################################################## From eb00e8bdba72087d540d0c1fd8ff01d5eb0a1e97 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 20 Nov 2023 12:16:36 +0100 Subject: [PATCH 43/48] Update src/spikeinterface/comparison/groundtruthstudy.py Co-authored-by: Garcia Samuel --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index adc2898071..23d13c0afe 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -21,7 +21,7 @@ # This is to separate names when the key are tuples when saving folders -_key_separator = "--" +_key_separator = "_##_" class GroundTruthStudy: From 1d20cfd061de6e624da6aedac76a2c9edf921d8a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Nov 2023 15:40:56 +0100 Subject: [PATCH 44/48] Few more fixes to API and viewers --- doc/api.rst | 43 ++++++++----------- doc/modules/widgets.rst | 5 ++- doc/viewers.rst | 2 +- .../comparison/groundtruthstudy.py | 2 +- .../postprocessing/spike_locations.py | 17 ++++---- src/spikeinterface/widgets/gtstudy.py | 11 ++--- src/spikeinterface/widgets/multicomparison.py | 4 -- 7 files changed, 38 insertions(+), 46 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 97c956c2f6..ab81b1596a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -239,7 +239,6 @@ spikeinterface.comparison .. autofunction:: compare_sorter_to_ground_truth .. autofunction:: compare_templates .. autofunction:: compare_multiple_templates - .. autofunction:: aggregate_performances_table .. autofunction:: create_hybrid_units_recording .. autofunction:: create_hybrid_spikes_recording @@ -272,12 +271,22 @@ spikeinterface.widgets .. autofunction:: set_default_plotter_backend .. autofunction:: get_default_plotter_backend + .. autofunction:: plot_agreement_matrix .. autofunction:: plot_all_amplitudes_distributions .. autofunction:: plot_amplitudes .. autofunction:: plot_autocorrelograms + .. autofunction:: plot_confusion_matrix + .. autofunction:: plot_comparison_collision_by_similarity .. autofunction:: plot_crosscorrelograms + .. autofunction:: plot_isi_distribution .. autofunction:: plot_motion + .. autofunction:: plot_multicomparison_agreement + .. autofunction:: plot_multicomparison_agreement_by_sorter + .. autofunction:: plot_multicomparison_graph + .. autofunction:: plot_peak_activity + .. autofunction:: plot_probe_map .. autofunction:: plot_quality_metrics + .. autofunction:: plot_rasters .. autofunction:: plot_sorting_summary .. autofunction:: plot_spike_locations .. autofunction:: plot_spikes_on_traces @@ -286,34 +295,18 @@ spikeinterface.widgets .. autofunction:: plot_traces .. autofunction:: plot_unit_depths .. autofunction:: plot_unit_locations + .. autofunction:: plot_unit_presence + .. autofunction:: plot_unit_probe_map .. autofunction:: plot_unit_summary .. autofunction:: plot_unit_templates .. autofunction:: plot_unit_waveforms_density_map .. autofunction:: plot_unit_waveforms - - -Legacy widgets -~~~~~~~~~~~~~~ - -These widgets are only available with the "matplotlib" backend - -.. automodule:: spikeinterface.widgets - :noindex: - - .. autofunction:: plot_rasters - .. autofunction:: plot_probe_map - .. autofunction:: plot_isi_distribution - .. autofunction:: plot_peak_activity_map - .. autofunction:: plot_principal_component - .. autofunction:: plot_unit_probe_map - .. autofunction:: plot_confusion_matrix - .. autofunction:: plot_agreement_matrix - .. autofunction:: plot_multicomp_graph - .. autofunction:: plot_multicomp_agreement - .. autofunction:: plot_multicomp_agreement_by_sorter - .. autofunction:: plot_comparison_collision_pair_by_pair - .. autofunction:: plot_comparison_collision_by_similarity - .. autofunction:: plot_sorting_performance + .. autofunction:: plot_study_run_times + .. autofunction:: plot_study_unit_counts + .. autofunction:: plot_study_performances + .. autofunction:: plot_study_agreement_matrix + .. autofunction:: plot_study_summary + .. autofunction:: plot_study_comparison_collision_by_similarity spikeinterface.exporters diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 68106d13a9..ca097b4729 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -14,8 +14,9 @@ Since version 0.95.0, the :py:mod:`spikeinterface.widgets` module supports multi * | :code:`sortingview`: web-based and interactive rendering using the `sortingview `_ | and `FIGURL `_ packages. -Version 0.100.0, also comes with this new backend: -* | :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package +Version 0.99.0 also comes with this new backend: + +* :code:`ephyviewer`: interactive Qt based using the `ephyviewer `_ package Installing backends diff --git a/doc/viewers.rst b/doc/viewers.rst index a906ee29db..c3ada31b55 100644 --- a/doc/viewers.rst +++ b/doc/viewers.rst @@ -16,7 +16,7 @@ spikeinterface.widgets The easiest way to visualize :code:`spikeinterface` objects is to use the :code:`widgets` module for plotting. You can find an extensive description in the module documentation :ref:`modulewidgets` -and many examples in this tutorial :ref:`sphx_glr_modules_gallery_widgets`. +and many examples in the :code:`Widgets tutorials` section of the :code:`Modules example gallery`. spikeinterface-gui ------------------ diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 0d08922543..7541c394b3 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -38,7 +38,7 @@ class GroundTruthStudy: In this case, the result dataframes will have `MultiIndex` to handle the different levels. A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see - :py:fun:`~spikeinterface.core.generate.generate_ground_truth_recording()`). + :py:func:`~spikeinterface.core.generate.generate_ground_truth_recording()`). This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. Note that the underlying folder structure is not backward compatible! diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 72d44bf348..dfa940b979 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -150,14 +150,15 @@ def compute_spike_locations( spike_retriver_kwargs: dict A dictionary to control the behavior for getting the maximum channel for each spike This dictionary contains: - * channel_from_template: bool, default: True - For each spike is the maximum channel computed from template or re estimated at every spikes - channel_from_template = True is old behavior but less acurate - channel_from_template = False is slower but more accurate - * radius_um: float, default: 50 - In case channel_from_template=False, this is the radius to get the true peak - * peak_sign, default: "neg" - In case channel_from_template=False, this is the peak sign. + + * channel_from_template: bool, default: True + For each spike is the maximum channel computed from template or re estimated at every spikes + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate + * radius_um: float, default: 50 + In case channel_from_template=False, this is the radius to get the true peak + * peak_sign, default: "neg" + In case channel_from_template=False, this is the peak sign. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use method_kwargs : dict, default: dict() diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 5e934f9702..91e2c382b4 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -297,11 +297,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class StudySummary(BaseWidget): """ Plot a summary of a ground truth study. - Internally does: - plot_study_run_times - plot_study_unit_counts - plot_study_performances - plot_study_agreement_matrix + Internally this plotting function runs: + + * plot_study_run_times + * plot_study_unit_counts + * plot_study_performances + * plot_study_agreement_matrix Parameters ---------- diff --git a/src/spikeinterface/widgets/multicomparison.py b/src/spikeinterface/widgets/multicomparison.py index fb34156fef..0917869f8c 100644 --- a/src/spikeinterface/widgets/multicomparison.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -206,10 +206,6 @@ class MultiCompAgreementBySorterWidget(BaseWidget): show_legend: bool Show the legend in the last axes - Returns - ------- - W: MultiCompGraphWidget - The output widget """ def __init__( From 4cad6bbfcb76b7c2db4dcdb8e05eb9e38a090960 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Nov 2023 16:27:55 +0100 Subject: [PATCH 45/48] Update doc/modules/sorters.rst --- doc/modules/sorters.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index e0005e51fe..98a5ea4fcf 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -14,7 +14,7 @@ module. **Note that internal sorters are currently experimental and under develo A drawback of using external sorters is the separate installation of these tools. Sometimes they need MATLAB, specific versions of CUDA, specific gcc versions or outdated versions of -Python/NumPy. In this case, SpikeInterface offer the mechanism of running external sorters inside a +Python/NumPy. In this case, SpikeInterface offers the mechanism of running external sorters inside a container (Docker/Singularity) with the sorter pre-installed. See :ref:`containerizedsorters`. From 35b702612783c9c4675d10d66d4b7941986d2126 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Mon, 20 Nov 2023 16:01:49 -0500 Subject: [PATCH 46/48] additional typo fixes in documentation --- doc/modules/core.rst | 6 +++--- doc/modules/widgets.rst | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 697681bcb1..656176f27a 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -258,7 +258,7 @@ Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s we_zarr = we.save(folder="waveforms_zarr", format="zarr") # extract sparse waveforms (see Sparsity section) - # this will use 50 spike per unit to estimate the sparsity of 40um radius for each unit + # this will use 50 spikes per unit to estimate the sparsity within a 40um radius from that unit we_sparse = extract_waveforms(recording=recording, sorting=sorting, folder="waveforms_sparse", @@ -575,7 +575,7 @@ In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikein but they are not bound to a file. Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to -Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an unerlying SharedMemory which is useful for +Similar to :py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory which is useful for parallel computing. In this example, we create a recording and a sorting object from numpy objects: @@ -617,7 +617,7 @@ Any sorting object can be transformed into a :py:class:`~spikeinterface.core.Num # turn any sortinto into NumpySorting sorting_np = sorting.to_numpy_sorting() - # or to SharedMemorySorting for parrallel computing + # or to SharedMemorySorting for parallel computing sorting_shm = sorting.to_shared_memory_sorting() diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index ca097b4729..4d69867d83 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -265,7 +265,7 @@ Available plotting functions * :py:func:`~spikeinterface.widgets.plot_amplitudes` (backends: :code:`matplotlib`, :code:`ipywidgets`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_autocorrelograms` (backends: :code:`matplotlib`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_confusion_matrix` (backends: :code:`matplotlib`) -* :py:func:`~spikeinterface.widgets.plot_comparison_collision_by_similarity` +* :py:func:`~spikeinterface.widgets.plot_comparison_collision_by_similarity` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_crosscorrelograms` (backends: :code:`matplotlib`, :code:`sortingview`) * :py:func:`~spikeinterface.widgets.plot_isi_distribution` (backends: :code:`matplotlib`) * :py:func:`~spikeinterface.widgets.plot_motion` (backends: :code:`matplotlib`) From 231f352bda0212f412c44325a8b001bb7c1789d2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 Nov 2023 12:14:49 +0100 Subject: [PATCH 47/48] Move extract_waveforms_to_single_buffer to tools.py --- .../sorters/internal/tridesclous2.py | 46 +++----------- .../sortingcomponents/clustering/circus.py | 38 +---------- src/spikeinterface/sortingcomponents/tools.py | 63 ++++++++++++++----- 3 files changed, 55 insertions(+), 92 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index eb2ddc922d..9e67bbf4f4 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -9,12 +9,14 @@ NumpySorting, get_channel_distances, ) -from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer + from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel + import numpy as np import pickle @@ -115,9 +117,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("We kept %d peaks for clustering" % len(peaks)) + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, **job_kwargs) + few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["svd"]["n_components"]) @@ -129,8 +134,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): with open(model_folder / "pca_model.pkl", "wb") as f: pickle.dump(tsvd, f) - ms_before = params["waveforms"]["ms_before"] - ms_after = params["waveforms"]["ms_after"] model_params = { "ms_before": ms_before, "ms_after": ms_after, @@ -321,37 +324,4 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): return sorting -def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): - """ - Helper function to extractor waveforms at max channel from a peak list - - - """ - n = rec.get_num_channels() - unit_ids = np.arange(n, dtype="int64") - sparsity_mask = np.eye(n, dtype="bool") - - spikes = np.zeros( - peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - ) - spikes["sample_index"] = peaks["sample_index"] - spikes["unit_index"] = peaks["channel_index"] - spikes["segment_index"] = peaks["segment_index"] - - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) - - all_wfs = extract_waveforms_to_single_buffer( - rec, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - sparsity_mask=sparsity_mask, - copy=True, - **job_kwargs, - ) - - return all_wfs + diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 47c5a1e58f..4dbd88c411 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,8 +18,6 @@ from .clustering_tools import remove_duplicates, remove_duplicates_via_matching, remove_duplicates_via_dip from spikeinterface.core import NumpySorting from spikeinterface.core import extract_waveforms -from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks -from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from sklearn.decomposition import TruncatedSVD @@ -30,43 +28,9 @@ ExtractSparseWaveforms, PeakRetriever, ) +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel -def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): - """ - Helper function to extract waveforms at the max channel from a peak list - - - """ - n = rec.get_num_channels() - unit_ids = np.arange(n, dtype="int64") - sparsity_mask = np.eye(n, dtype="bool") - - spikes = np.zeros( - peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - ) - spikes["sample_index"] = peaks["sample_index"] - spikes["unit_index"] = peaks["channel_index"] - spikes["segment_index"] = peaks["segment_index"] - - nbefore = int(ms_before * rec.sampling_frequency / 1000.0) - nafter = int(ms_after * rec.sampling_frequency / 1000.0) - - all_wfs = extract_waveforms_to_single_buffer( - rec, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - sparsity_mask=sparsity_mask, - copy=True, - **job_kwargs, - ) - - return all_wfs - class CircusClustering: """ diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index cd9226d5e8..1e8a933990 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever - +from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer def make_multi_method_doc(methods, ident=" "): doc = "" @@ -18,23 +18,52 @@ def make_multi_method_doc(methods, ident=" "): return doc -def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - # TODO for Pierre: this function is really inefficient because it runs a full pipeline only for a few - # spikes, which means that all traces need to be accesses! Please find a better way - nb_peaks = min(len(peaks), nb_peaks) - idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) - peak_retriever = PeakRetriever(recording, peaks[idx]) - - sparse_waveforms = ExtractSparseWaveforms( - recording, - parents=[peak_retriever], - ms_before=ms_before, - ms_after=ms_after, - return_output=True, - radius_um=5, +def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs): + """ + Helper function to extract waveforms at the max channel from a peak list + + + """ + n = rec.get_num_channels() + unit_ids = np.arange(n, dtype="int64") + sparsity_mask = np.eye(n, dtype="bool") + + spikes = np.zeros( + peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + ) + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] + + nbefore = int(ms_before * rec.sampling_frequency / 1000.0) + nafter = int(ms_after * rec.sampling_frequency / 1000.0) + + all_wfs = extract_waveforms_to_single_buffer( + rec, + spikes, + unit_ids, + nbefore, + nafter, + mode="shared_memory", + return_scaled=False, + sparsity_mask=sparsity_mask, + copy=True, + **job_kwargs, ) - nbefore = sparse_waveforms.nbefore - waveforms = run_node_pipeline(recording, [peak_retriever, sparse_waveforms], job_kwargs=job_kwargs) + return all_wfs + + +def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): + + if peaks.size > nb_peaks: + idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) + some_peaks = peaks[idx] + else: + some_peaks = peaks + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + + waveforms = extract_waveform_at_max_channel(recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0) return prototype From 20cc70ef9b6e4fb6a18e6942fb2449000105f82e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Nov 2023 11:16:39 +0000 Subject: [PATCH 48/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/tridesclous2.py | 7 +++---- src/spikeinterface/sortingcomponents/clustering/circus.py | 1 - src/spikeinterface/sortingcomponents/tools.py | 6 ++++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 9e67bbf4f4..6d53414c9f 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -122,7 +122,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # SVD for time compression few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel(recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) wfs = few_wfs[:, :, 0] tsvd = TruncatedSVD(params["svd"]["n_components"]) @@ -322,6 +324,3 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting = sorting.save(folder=sorter_output_folder / "sorting") return sorting - - - diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 4dbd88c411..238b16260c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -31,7 +31,6 @@ from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel - class CircusClustering: """ hdbscan clustering on peak_locations previously done by localize_peaks() diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1e8a933990..328e3b715d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -3,6 +3,7 @@ from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractSparseWaveforms, PeakRetriever from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer + def make_multi_method_doc(methods, ident=" "): doc = "" @@ -55,7 +56,6 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0.5, ms_after=0.5): - if peaks.size > nb_peaks: idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False)) some_peaks = peaks[idx] @@ -64,6 +64,8 @@ def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0 nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - waveforms = extract_waveform_at_max_channel(recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs) + waveforms = extract_waveform_at_max_channel( + recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) prototype = np.median(waveforms[:, :, 0] / (waveforms[:, nbefore, 0][:, np.newaxis]), axis=0) return prototype