diff --git a/doc/api.rst b/doc/api.rst
index 6bb9b39091..eb9a61eb9c 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -346,6 +346,9 @@ spikeinterface.curation
.. autofunction:: remove_redundant_units
.. autofunction:: remove_duplicated_spikes
.. autofunction:: remove_excess_spikes
+ .. autofunction:: load_model
+ .. autofunction:: auto_label_units
+ .. autofunction:: train_model
Deprecated
~~~~~~~~~~
diff --git a/doc/conf.py b/doc/conf.py
index e3d58ca8f2..d229dc18ee 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -119,12 +119,15 @@
# for sphinx gallery plugin
sphinx_gallery_conf = {
- 'only_warn_on_example_error': True,
+ # This is the default, but we include it here explicitly: build all docs and fail on gallery failures only.
+ # The other option would be abort_on_example_error, but that fails on the first failure, so we decided against it.
+ 'only_warn_on_example_error': False,
'examples_dirs': ['../examples/tutorials'],
'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples
'subsection_order': ExplicitOrder([
'../examples/tutorials/core',
'../examples/tutorials/extractors',
+ '../examples/tutorials/curation',
'../examples/tutorials/qualitymetrics',
'../examples/tutorials/comparison',
'../examples/tutorials/widgets',
diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst
new file mode 100644
index 0000000000..9b1612ec12
--- /dev/null
+++ b/doc/how_to/auto_curation_prediction.rst
@@ -0,0 +1,43 @@
+How to use a trained model to predict the curation labels
+=========================================================
+
+For a more detailed guide to using trained models, `read our tutorial here
+`_.
+
+There is a collection of models for automated curation available on the
+`SpikeInterface HuggingFace page `_.
+
+We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` to a SortingAnalyzer
+called ``sorting_analyzer``. We assume that the quality and template metrics have
+already been computed.
+
+We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the
+repo's URL after huggingface.co/) and confirm that we trust the model.
+
+.. code::
+
+    from spikeinterface.curation import auto_label_units
+
+    labels_and_probabilities = auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        repo_id = "SpikeInterface/toy_tetrode_model",
+        trust_model = True
+    )
+
+If you have a local directory containing the model in a ``skops`` file, you can use this to
+create the labels:
+
+.. code::
+
+    labels_and_probabilities = si.auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        model_folder = "my_folder_with_a_model_in_it",
+    )
+
+The returned labels are a dataframe of the model's predictions and their confidences. These
+are also saved as a property of your ``sorting_analyzer`` and can be accessed like so:
+
+.. code::
+
+    labels = sorting_analyzer.sorting.get_property("classifier_label")
+    probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
diff --git a/doc/how_to/auto_curation_training.rst b/doc/how_to/auto_curation_training.rst
new file mode 100644
index 0000000000..20ab57d284
--- /dev/null
+++ b/doc/how_to/auto_curation_training.rst
@@ -0,0 +1,58 @@
+How to train a model to predict curation labels
+===============================================
+
+A full tutorial for model-based curation can be found `here `_.
+
+Here, we assume that you have:
+
+* Two SortingAnalyzers called ``analyzer_1`` and
+  ``analyzer_2``, and have calculated some template and quality metrics for both
+* Manually curated labels for the units in each analyzer, in lists called
+  ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can
+  be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``.
+
+With these objects calculated, you can train a model as follows:
+
+.. code::
+
+    from spikeinterface.curation import train_model
+
+    analyzer_list = [analyzer_1, analyzer_2]
+    labels_list = [analyzer_1_labels, analyzer_2_labels]
+    output_folder = "/path/to/output_folder"
+
+    trainer = train_model(
+        mode="analyzers",
+        labels=labels_list,
+        analyzers=analyzer_list,
+        output_folder=output_folder,
+        metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics
+        imputation_strategies=None, # Default is all available imputation strategies
+        scaling_techniques=None, # Default is all available scaling techniques
+        classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available
+        seed=None, # Set a seed for reproducibility
+    )
+
+
+The trainer tries several models and chooses the most accurate one. This model and
+some metadata are stored in the ``output_folder``, which can later be loaded using the
+``load_model`` function (`more details `_).
+We can also access the model, which is an sklearn ``Pipeline``, from the trainer object:
+
+.. code::
+
+    best_model = trainer.best_pipeline
+
+
+The training function can also be run in “csv” mode, if you prefer to
+store metrics as .csv files. If the target labels are stored as a column in
+the file, you can point to these with the ``target_label`` parameter:
+
+.. code::
+
+    trainer = train_model(
+        mode="csv",
+        metrics_paths = ["/path/to/csv_file_1", "/path/to/csv_file_2"],
+        target_label = "my_label",
+        output_folder=output_folder,
+    )
diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index 5d7eae9003..7f79156a3b 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -15,3 +15,5 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
load_your_data_into_sorting
benchmark_with_hybrid_recordings
drift_with_lfp
+ auto_curation_training
+ auto_curation_prediction
diff --git a/doc/images/files_screen.png b/doc/images/files_screen.png
new file mode 100644
index 0000000000..ef2b5b0873
Binary files /dev/null and b/doc/images/files_screen.png differ
diff --git a/doc/images/hf-logo.svg b/doc/images/hf-logo.svg
new file mode 100644
index 0000000000..ab959d165f
--- /dev/null
+++ b/doc/images/hf-logo.svg
@@ -0,0 +1,8 @@
+ + + + + + + +
diff --git a/doc/images/initial_model_screen.png b/doc/images/initial_model_screen.png
new file mode 100644
index 0000000000..b01c4248a6
Binary files /dev/null and b/doc/images/initial_model_screen.png differ
diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst
index 4c7625d811..82f2c06eed 100644
--- a/doc/tutorials_custom_index.rst
+++ b/doc/tutorials_custom_index.rst
@@ -119,8 +119,8 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various
.. grid-item-card:: Quality Metrics
:link-type: ref
- :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_mertics.py
- :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_mertics_thumb.png
+ :link: sphx_glr_tutorials_qualitymetrics_plot_3_quality_metrics.py
+ :img-top: /tutorials/qualitymetrics/images/thumb/sphx_glr_plot_3_quality_metrics_thumb.png
:img-alt: Quality Metrics
:class-card: gallery-card
:text-align: center
@@ -133,6 +133,39 @@ The :code:`spikeinterface.qualitymetrics` module allows users to compute various
:class-card: gallery-card
:text-align: center
+Automated curation tutorials
+----------------------------
+
+Learn how to curate your units using a trained machine learning model, or how to create
+and share your own model.
+
+.. grid:: 1 2 2 3
+    :gutter: 2
+
+    .. grid-item-card:: Model-based curation
+        :link-type: ref
+        :link: sphx_glr_tutorials_curation_plot_1_automated_curation.py
+        :img-top: /tutorials/curation/images/sphx_glr_plot_1_automated_curation_002.png
+        :img-alt: Model-based curation
+        :class-card: gallery-card
+        :text-align: center
+
+    .. grid-item-card:: Train your own model
+        :link-type: ref
+        :link: sphx_glr_tutorials_curation_plot_2_train_a_model.py
+        :img-top: /tutorials/curation/images/thumb/sphx_glr_plot_2_train_a_model_thumb.png
+        :img-alt: Train your own model
+        :class-card: gallery-card
+        :text-align: center
+
+    .. grid-item-card:: Upload your model to HuggingFaceHub
+        :link-type: ref
+        :link: sphx_glr_tutorials_curation_plot_3_upload_a_model.py
+        :img-top: /images/hf-logo.svg
+        :img-alt: Upload your model
+        :class-card: gallery-card
+        :text-align: center
+
Comparison tutorial
-------------------
diff --git a/examples/tutorials/curation/README.rst b/examples/tutorials/curation/README.rst
new file mode 100644
index 0000000000..0f64179e65
--- /dev/null
+++ b/examples/tutorials/curation/README.rst
@@ -0,0 +1,5 @@
+Curation tutorials
+------------------
+
+Learn how to use models to automatically curate your sorted data, or generate models
+based on your own curation.
diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py
new file mode 100644
index 0000000000..e88b0973df
--- /dev/null
+++ b/examples/tutorials/curation/plot_1_automated_curation.py
@@ -0,0 +1,287 @@
+"""
+Model-based curation tutorial
+=============================
+
+Sorters are not perfect. They output excellent units, as well as noisy ones, and ones that
+should be split or merged. Hence one should curate the generated units. Historically, this
+has been done using laborious manual curation. An alternative is to use automated methods
+based on metrics which quantify features of the units. In spikeinterface these are the
+quality metrics and the template metrics. A simple approach is to use thresholding:
+only accept units whose metrics pass a certain quality threshold. Another approach is to
+take one (or more) manually labelled sortings, whose metrics have been computed, and train
+a machine learning model to predict labels.
+
+This notebook provides a step-by-step guide on how to take a machine learning model that
+someone else has trained and use it to curate your own spike sorted output. SpikeInterface
+also provides the tools to train your own model,
+`which you can learn about here `_.
+
+We'll download a toy model and use it to label our sorted data.
We start by importing some packages.
+"""
+
+import warnings
+warnings.filterwarnings("ignore")
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import spikeinterface.core as si
+import spikeinterface.curation as sc
+import spikeinterface.widgets as sw
+
+# note: you can use more cores with e.g.
+# si.set_global_job_kwargs(n_jobs = 8)
+
+##############################################################################
+# Download a pretrained model
+# ---------------------------
+#
+# Let's download a pretrained model from `Hugging Face `_ (HF),
+# a model sharing platform focused on AI and ML models and datasets. The
+# ``load_model`` function allows us to download directly from HF, or use a model in a local
+# folder. The function downloads the model, saves it in a temporary folder, and returns the
+# model together with some metadata about it.
+
+model, model_info = sc.load_model(
+    repo_id = "SpikeInterface/toy_tetrode_model",
+    trusted = ['numpy.dtype']
+)
+
+
+##############################################################################
+# This model was trained on artificially generated tetrode data. There are also models trained
+# on real data, like the one discussed `below <#A-model-trained-on-real-Neuropixels-data>`_.
+# Each model object has a nice html representation, which will appear if you're using a Jupyter notebook.

+model

+##############################################################################
+# This tells us more information about the model. The one we've just downloaded was trained using
+# a ``RandomForestClassifier``. You can also discover this information by running
+# ``model.get_params()``. The model object (an `sklearn Pipeline `_) also contains information
+# about which metrics were used to train the model. We can access it from the model (or from the model_info):

+print(model.feature_names_in_)

+##############################################################################
+# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed.
+# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a
+# bunch of extensions. Follow these links for more info on `recordings `_, `sortings `_, `sorting analyzers `_
+# and `extensions `_.

+recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10)
+sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
+sorting_analyzer.compute(['noise_levels','random_spikes','waveforms','templates','spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics'])
+sorting_analyzer.compute('template_metrics', include_multi_channel_metrics=True)

+##############################################################################
+# This sorting_analyzer now contains the required quality metrics and template metrics.
+# We can check that this is true by accessing the extension data.

+all_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys()) + list(sorting_analyzer.get_extension('template_metrics').get_data().keys())
+print(set(all_metric_names) == set(model.feature_names_in_))

+##############################################################################
+# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
+# to the ``auto_label_units`` function.
This returns a dataframe containing a label and
+# a confidence for each unit contained in the ``sorting_analyzer``.

+labels = sc.auto_label_units(
+    sorting_analyzer = sorting_analyzer,
+    repo_id = "SpikeInterface/toy_tetrode_model",
+    trusted = ['numpy.dtype']
+)

+print(labels)


+##############################################################################
+# The model has labelled one unit as bad. Let's look at that one, and also the 'good' unit
+# with the highest confidence of being 'good'.

+sw.plot_unit_templates(sorting_analyzer, unit_ids=['7','9'])

+##############################################################################
+# Nice! Unit 9 looks more like an expected action potential waveform while unit 7 doesn't,
+# and it seems reasonable that unit 7 is labelled as `bad`. However, for certain experiments
+# or brain areas, unit 7 might be a great small-amplitude unit. This example highlights that
+# you should be careful applying models trained on one dataset to your own dataset. You can
+# explore the currently available models on the `spikeinterface hugging face hub `_
+# page, or `train your own `_.
+#
+# Assess the model performance
+# ----------------------------
+#
+# To assess the performance of the model relative to labels assigned by a human curator, we can load or generate some
+# "human labels", and plot a confusion matrix of predicted vs human labels for all units. Here
+# we'll be a conservative human, who has labelled several units with small amplitudes as 'bad'.

+human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good']

+# Note: if you labelled using phy, you can load the labels using:
+# human_labels = sorting_analyzer.sorting.get_property('quality')
+# We need to load in the `label_conversion` dictionary, which converts integers such
+# as '0' and '1' to readable labels such as 'good' and 'bad'. This is stored
+# in `model_info`, which we loaded earlier.

+from sklearn.metrics import confusion_matrix, balanced_accuracy_score

+label_conversion = model_info['label_conversion']
+predictions = labels['prediction']

+conf_matrix = confusion_matrix(human_labels, predictions)

+# Calculate balanced accuracy for the confusion matrix
+balanced_accuracy = balanced_accuracy_score(human_labels, predictions)

+plt.imshow(conf_matrix)
+for (index, value) in np.ndenumerate(conf_matrix):
+    plt.annotate( str(value), xy=index, color="white", fontsize="15")
+plt.xlabel('Predicted Label')
+plt.ylabel('Human Label')
+plt.xticks(ticks = [0, 1], labels = list(label_conversion.values()))
+plt.yticks(ticks = [0, 1], labels = list(label_conversion.values()))
+plt.title('Predicted vs Human Label')
+plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}")
+plt.show()


+##############################################################################
+# Here, there are several false positives (if we consider the human labels to be "the truth").
+#
+# Next, we can also see how the model's confidence relates to the probability that the model
+# label matches the human label.
+#
+# This could be used to help decide which units should be auto-curated and which need further
+# manual curation. For example, we might accept any unit as 'good' that the model predicts
+# as 'good' with confidence over a threshold, say 80%. If the confidence is lower we might decide to take a
+# look at this unit manually.
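+#
+# As a minimal sketch of this triage idea (the 80% cut-off is an arbitrary choice
+# for illustration), one could split the units like so:
+#
+# .. code-block::
+#
+#     confident_good = (labels["prediction"] == "good") & (labels["probability"] > 0.80)
+#     units_to_review = labels.index[~confident_good]
+#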
+# Below, we will create a plot that shows how the agreement
+# between human and model labels changes as we increase the confidence threshold. We see that
+# the agreement increases as the confidence does. So the model gets more accurate with a
+# higher confidence threshold, as expected.


+def calculate_moving_avg(label_df, confidence_label, window_size):

+    label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop')
+    # Group by decile and calculate the proportion of correct labels (agreement)
+    p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean()
+    # Convert decile to range 0-1
+    p_label_grouped.index = p_label_grouped.index / 10
+    # Sort the DataFrame by confidence scores
+    label_df_sorted = label_df.sort_values(by=confidence_label)

+    p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()

+    return label_df_sorted[confidence_label], p_label_moving_avg

+confidences = labels['probability']

+# Make dataframe of human label, model label, and confidence
+label_df = pd.DataFrame(data = {
+    'human_label': human_labels,
+    'decoder_label': predictions,
+    'confidence': confidences},
+    index = sorting_analyzer.sorting.get_unit_ids())

+# Calculate the proportion of agreed labels by confidence decile
+label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label']

+p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3)

+# Plot the moving average of agreement
+plt.figure(figsize=(6, 6))
+plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average')
+plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')
+plt.xlabel('Confidence'); #plt.xlim(0.5, 1)
+plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)
+plt.title('Agreement vs Confidence (Moving Average)')
+plt.legend(); plt.grid(True); plt.show()

+##############################################################################
+# In this case, you might decide to only trust labels which had confidence above 0.88,
+# and manually label the ones the model isn't so confident about.
+#
+# A model trained on real Neuropixels data
+# ----------------------------------------
+#
+# Above, we used a toy model trained on generated data. There are also models on HuggingFace
+# trained on real data.
+#
+# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
+# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and
+# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into
+# `noise` or `not-noise` and the other will classify the `not-noise` units into single
+# unit activity (sua) units and multi-unit activity (mua) units.
+#
+# There is more information about the model on the model's HuggingFace page. Take a look!
+# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
+# We can do so as follows:
+#

+# Apply the noise/not-noise model
+noise_neuron_labels = sc.auto_label_units(
+    sorting_analyzer = sorting_analyzer,
+    repo_id = "AnoushkaJain3/noise_neural_classifier",
+    trust_model=True,
+)

+noise_units = noise_neuron_labels[noise_neuron_labels['prediction']=='noise']
+analyzer_neural = sorting_analyzer.remove_units(noise_units.index)

+# Apply the sua/mua model
+sua_mua_labels = sc.auto_label_units(
+    sorting_analyzer = analyzer_neural,
+    repo_id = "AnoushkaJain3/sua_mua_classifier",
+    trust_model=True,
+)

+all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
+print(all_labels)

+##############################################################################
+# If you run this without the ``trust_model=True`` parameter, you will receive an error:
+#
+# .. code-block::
+#
+#     UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
+#
+# This is a security warning, which can be overcome by passing the trusted types list
+# ``trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']``
+# or by passing the ``trust_model=True`` keyword.
+#
+# .. dropdown:: More about security
+#
+#     Sharing models, which are Python objects, is complicated.
+#     We have chosen to use the `skops format `_, instead
+#     of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
+#     `here `_). While unpacking the ``.skops`` file, each function
+#     is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and
+#     allow the object to be loaded if it only contains these (and no unknown malicious code). But
+#     when ``skops`` is not sure, it raises an error. Here, it doesn't recognise
+#     ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer',
+#     'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
+#     'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
+#     from `sklearn`, and we can happily add them to the ``trusted`` functions to load.
+#
+#     In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos,
+#     especially from unknown sources.
+#
+# Directly applying a sklearn Pipeline
+# ------------------------------------
+#
+# Instead of using ``HuggingFace`` and ``skops``, someone might have given you a model
+# in a different way: perhaps by e-mail or a download. If you have the model in a
+# folder, you can apply it in a very similar way:
+#
+# .. code-block::
+#
+#     labels = sc.auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         model_folder = "path/to/model/folder",
+#     )

+##############################################################################
+# Using this, you lose the advantages of the model metadata: the quality metric parameters
+# are not checked and the labels are not converted to their original human-readable names (like
+# 'good' and 'bad'). Hence we advise using the methods discussed above, when possible.
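+
+##############################################################################
+# Once you are happy with a set of labels, a typical final step is to remove the
+# units labelled as 'bad'. This is a minimal sketch, assuming the labels were
+# stored on the sorting by ``auto_label_units`` as shown above:
+#
+# .. code-block::
+#
+#     labels = sorting_analyzer.sorting.get_property("classifier_label")
+#     bad_unit_ids = sorting_analyzer.unit_ids[labels == "bad"]
+#     curated_analyzer = sorting_analyzer.remove_units(bad_unit_ids)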
diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py
new file mode 100644
index 0000000000..1a38836527
--- /dev/null
+++ b/examples/tutorials/curation/plot_2_train_a_model.py
@@ -0,0 +1,168 @@
+"""
+Training a model for automated curation
+========================================
+
+If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface.
+"""


+##############################################################################
+# Step 1: Generate and label data
+# -------------------------------
+#
+# First we will import our dependencies
+import warnings
+warnings.filterwarnings("ignore")
+from pathlib import Path
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt

+import spikeinterface.core as si
+import spikeinterface.curation as sc
+import spikeinterface.widgets as sw

+# Note: you can set the number of cores you use with e.g.
+# si.set_global_job_kwargs(n_jobs = 8)

+##############################################################################
+# For this tutorial, we will use simulated data to create ``recording`` and ``sorting`` objects. We'll
+# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so the spike times of the sorter will
+# perfectly match the spikes in the recording. Hence this will contain good units. However, we've
+# uncoupled :code:`sorting_2` from the recording and the spike times will not be matched with the spikes in the recording.
+# Hence these units will mostly be random noise. We'll combine the "good" and "noise" sortings into one sorting
+# object using :code:`si.aggregate_units`.
+#
+# (When making your own model, you should
+# `load your own recording `_
+# and `do a sorting `_ on your data.)

+recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
+_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)

+both_sortings = si.aggregate_units([sorting_1, sorting_2])

+##############################################################################
+# To do some visualisation and postprocessing, we need to create a sorting analyzer, and
+# compute some extensions:

+analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording)
+analyzer.compute(['noise_levels','random_spikes','waveforms','templates'])

+##############################################################################
+# Now we can plot the templates for the first and sixth units. The first (unit id 0) belongs to
+# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belongs to :code:`sorting_2`
+# so should look like noise.

+sw.plot_unit_templates(analyzer, unit_ids=["0", "5"])

+##############################################################################
+# This is as expected: great! (Find out more about plotting using widgets `here `_.)
+# We've set up our system so that the first five units are 'good' and the next five are 'bad'.
+# So we can make a list of labels which contains this information. For real data, you could
+# use a manual curation tool to make your own list.

+labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']

+##############################################################################
+# Step 2: Train our model
+# -----------------------
+#
+# We'll now train a model based on our labelled data.
The model will be trained using properties
+# of the units, and then be applied to units from other sortings. The properties we use are the
+# `quality metrics `_
+# and `template metrics `_.
+# Hence we need to compute these, using some ``sorting_analyzer`` extensions.

+analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics'])

+##############################################################################
+# Now that we have metrics and labels, we're ready to train the model using the
+# ``train_model`` function. The trainer will try several classifiers, imputation strategies and
+# scaling techniques, then save the most accurate. To save time in this tutorial,
+# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling
+# technique (standard scaler).
+#
+# We will use a list of one analyzer here, so the model is trained on a single
+# session. In reality, we would usually train a model using multiple analyzers from an
+# experiment, which should make the model more robust. To do this, you can simply pass
+# a list of analyzers and a list of manually curated labels for each
+# of these analyzers. Then the model would use all of these data as input.

+trainer = sc.train_model(
+    mode = "analyzers", # You can supply a labelled csv file instead of an analyzer
+    labels = [labels],
+    analyzers = [analyzer],
+    folder = "my_folder", # Where to save the model and model_info.json file
+    metric_names = None, # Specify which metrics to use for training: by default uses those already calculated
+    imputation_strategies = ["median"], # Defaults to all
+    scaling_techniques = ["standard_scaler"], # Defaults to all
+    classifiers = None, # Defaults to Random Forest only. Other classifiers you can try: ["AdaBoostClassifier","GradientBoostingClassifier","LogisticRegression","MLPClassifier"]
+    overwrite = True, # Whether or not to overwrite `folder` if it already exists. Default is False.
+    search_kwargs = {'cv': 3} # Parameters used during the model hyperparameter search
+)

+best_model = trainer.best_pipeline

+##############################################################################
+#
+# You can pass many sklearn `classifiers `_,
+# `imputation strategies `_ and
+# `scalers `_, although the
+# documentation is quite overwhelming. You can find the classifiers we've tried out
+# using the ``sc.get_default_classifier_search_spaces`` function.
+#
+# The above code saves the model in ``model.skops``, some metadata in
+# ``model_info.json`` and the model accuracies in ``model_accuracies.csv``
+# in the specified ``folder`` (in this case ``'my_folder'``).
+#
+# (``skops`` is a file format: you can think of it as a more-secure pkl file. `Read more `_.)
+#
+# The ``model_accuracies.csv`` file contains the accuracy, precision and recall of the
+# tested models. Let's take a look:

+accuracies = pd.read_csv(Path("my_folder") / "model_accuracies.csv", index_col = 0)
+accuracies.head()

+##############################################################################
+# Our model is perfect!! This is because the task was *very* easy. We had 10 units, where
+# half were pure noise and half were not.
+#
+# The model also contains some more information, such as which features are "important",
+# as defined by sklearn (learn about feature importance of a Random Forest Classifier
+# `here `_.)
We can plot these:

+# Plot feature importances
+importances = best_model.named_steps['classifier'].feature_importances_
+indices = np.argsort(importances)[::-1]

+# The sklearn importances are not computed for inputs whose values are all `nan`.
+# Hence, we need to pick out the non-`nan` columns of our metrics
+features = best_model.feature_names_in_
+n_features = best_model.n_features_in_

+metrics = pd.concat([analyzer.get_extension('quality_metrics').get_data(), analyzer.get_extension('template_metrics').get_data()], axis=1)
+non_null_metrics = ~(metrics.isnull().all()).values

+features = features[non_null_metrics]
+n_features = len(features)

+plt.figure(figsize=(12, 7))
+plt.title("Feature Importances")
+plt.bar(range(n_features), importances[indices], align="center")
+plt.xticks(range(n_features), features[indices], rotation=90)
+plt.xlim([-1, n_features])
+plt.subplots_adjust(bottom=0.3)
+plt.show()

+##############################################################################
+# Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio"
+# but is not using "sync_spike_4" and "rp_contamination". This is a toy model, so don't
+# take these results seriously. But using this information, you could retrain another,
+# simpler model using a subset of the metrics, by passing, e.g.,
+# ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to the ``train_model`` function.
+#
+# Now that you have a model, you can `apply it to another sorting
+# `_
+# or `upload it to HuggingFaceHub `_.
diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py
new file mode 100644
index 0000000000..0a9ea402db
--- /dev/null
+++ b/examples/tutorials/curation/plot_3_upload_a_model.py
@@ -0,0 +1,139 @@
+"""
+Upload a pipeline to Hugging Face Hub
+=====================================
+"""
+##############################################################################
+# In this tutorial we will upload a pipeline, trained in SpikeInterface, to the
+# `Hugging Face Hub `_ (HFH).
+#
+# To do this, you first need to train a model. `Learn how here! `_
+#
+# Hugging Face Hub?
+# -----------------
+# Hugging Face Hub (HFH) is a model sharing platform focused on AI and ML models and datasets.
+# To upload your own model to HFH, you need to make an account with them.
+# If you do not want to make an account, you can simply share the model folder with colleagues.
+# There are also several ways to interact with HFH: the way we propose here doesn't use
+# many of the tools ``skops`` and Hugging Face have developed, such as the ``Card`` and
+# ``hub_utils``. Feel free to check those out `here `_.
+#
+# Prepare your model
+# ------------------
+#
+# The plan is to make a folder with the following file structure:
+#
+# .. code-block::
+#
+#     my_model_folder/
+#         my_model_name.skops
+#         model_info.json
+#         training_data.csv
+#         labels.csv
+#         metadata.json
+#
+# SpikeInterface and HFH don't require you to keep this folder structure; we just advise it as
+# best practice.
+#
+# If you've used SpikeInterface to train your model, the ``train_model`` function auto-generates
+# most of this data. The only thing missing is the ``metadata.json`` file. The purpose of this
+# file is to detail how the model was trained, which can help prospective users decide if it
+# is relevant for them. For example, taking
+# a model trained on mouse data and applying it to a primate is likely a bad idea (or a
+# great research paper!).
And a model trained using tetrode data might have limited application
+# on silicon high-density probes. Hence we suggest saving at least the species, brain areas
+# and probe information, as is done in the dictionary below. Note that we format the metadata
+# so that the fields it shares with the NWB data format are consistent with it. Since the models can be trained
+# on several curations, all the metadata fields are lists:
+#
+# .. code-block::
+#
+#     import json
+#
+#     model_metadata = {
+#         "subject_species": ["Mus musculus"],
+#         "brain_areas": ["CA1"],
+#         "probes":
+#             [{
+#                 "manufacturer": "IMEC",
+#                 "name": "Neuropixels 2.0"
+#             }]
+#     }
+#     with open("my_model_folder/metadata.json", "w") as file:
+#         json.dump(model_metadata, file)
+#
+# Upload to HuggingFaceHub
+# ------------------------
+#
+# We'll now upload this folder to HFH using the web interface.
+#
+# First, go to https://huggingface.co/ and make an account. Once you've logged in, press
+# ``+`` then ``New model`` or find ``+ New Model`` in the user menu. You will be asked
+# to enter a model name, to choose a license for the model and whether the model should
+# be public or private. After you have made these choices, press ``Create Model``.
+#
+# You should be on your model's landing page, whose header looks something like this:
+#
+# .. image:: ../../images/initial_model_screen.png
+#     :width: 550
+#     :align: center
+#     :alt: The page shown on HuggingFaceHub when a user first initialises a model
+#
+# Click Files, then ``+ Add file``, then ``Upload file(s)``. You can then add your files to the repository. Upload these by pressing ``Commit changes to main``.
+#
+# You are returned to the Files page, which should look similar to this:
+#
+# .. image:: ../../images/files_screen.png
+#     :width: 700
+#     :align: center
+#     :alt: The file list for a model HuggingFaceHub.
+#
+# Let's add some information about the model for users to see when they go on your model's
+# page. Click on ``Model card`` then ``Edit model card``. Here is a sample model card
+# for a model based on synthetically generated tetrode data:
+#
+# .. code-block::
+#
+#     ---
+#     license: mit
+#     ---
+#
+#     ## Model description
+#
+#     A toy model, trained on toy data generated from spikeinterface.
+#
+#     # Intended use
+#
+#     Used to try out automated curation in SpikeInterface.
+#
+#     # How to Get Started with the Model
+#
+#     This can be used to automatically label a sorting in spikeinterface. Provided you have a `sorting_analyzer`, it is used as follows
+#
+#     ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
+#
+#     from spikeinterface.curation import auto_label_units
+#     labels = auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         repo_id = "SpikeInterface/toy_tetrode_model",
+#         trust_model=True
+#     )
+#     ` ` `
+#
+#     or you can download the entire repository to `a_folder_for_a_model`, and use
+#
+#     ` ` ` python
+#     from spikeinterface.curation import auto_label_units
+#
+#     labels = auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         model_folder = "path/to/a_folder_for_a_model",
+#         trusted = ['numpy.dtype']
+#     )
+#     ` ` `
+#
+#     # Authors
+#
+#     Chris Halcrow
+#
+# You can see the repo with this Model card `here `_.
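+
+##############################################################################
+# Uploading programmatically
+# --------------------------
+#
+# As an alternative to the web interface, the same folder can be uploaded from
+# Python with the ``huggingface_hub`` package. This is a minimal sketch, which
+# assumes you have created an access token on HFH; the repo name is just an
+# example:
+#
+# .. code-block::
+#
+#     from huggingface_hub import HfApi
+#
+#     api = HfApi(token="hf_...")  # your HFH access token
+#     api.create_repo(repo_id="my_username/my_model_name", repo_type="model")
+#     api.upload_folder(
+#         folder_path="my_model_folder",
+#         repo_id="my_username/my_model_name",
+#         repo_type="model",
+#     )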
diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_mertics.py b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py
similarity index 100%
rename from examples/tutorials/qualitymetrics/plot_3_quality_mertics.py
rename to examples/tutorials/qualitymetrics/plot_3_quality_metrics.py
diff --git a/examples/tutorials/widgets/plot_2_sort_gallery.py b/examples/tutorials/widgets/plot_2_sort_gallery.py
index da5c611ce4..056b5e3a8d 100644
--- a/examples/tutorials/widgets/plot_2_sort_gallery.py
+++ b/examples/tutorials/widgets/plot_2_sort_gallery.py
@@ -31,14 +31,14 @@
# plot_autocorrelograms()
# ~~~~~~~~~~~~~~~~~~~~~~~~
-w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5])
+w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5'])
##############################################################################
# plot_crosscorrelograms()
# ~~~~~~~~~~~~~~~~~~~~~~~~
-w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5])
+w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5'])
plt.show()
diff --git a/pyproject.toml b/pyproject.toml
index 22fbdc7f22..0b2f06049f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -101,6 +101,8 @@ full = [
"matplotlib>=3.6", # matplotlib.colormaps
"cuda-python; platform_system != 'Darwin'",
"numba",
+ "skops",
+ "huggingface_hub"
]
widgets = [
@@ -171,6 +173,10 @@ test = [
"torch",
"pynndescent",
+ # curation
+ "skops",
+ "huggingface_hub",
+
# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
@@ -192,6 +198,8 @@ docs = [
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"networkx",
+ "skops", # For automated curation
+ "scikit-learn", # For automated curation
# Download data
"pooch>=1.8.2",
"datalad>=1.0.2",
diff --git a/readthedocs.yml b/readthedocs.yml
index 512fcbc709..c6c44d83a0 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -1,5 +1,9 @@
version: 2
+sphinx:
+ # Path to your Sphinx configuration file.
+ configuration: doc/conf.py + build: os: ubuntu-22.04 tools: diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index b9cbf269c8..fc1b136d2d 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): for key in case_keys: result_folder = self.folder / "results" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) if keep and result_folder.exists(): continue - elif not keep and result_folder.exists(): + elif not keep and (result_folder.exists() or sorter_folder.exists()): self.remove_benchmark(key) job_keys.append(key) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py index abb2a51bae..5a3c490d38 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,6 +109,8 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) + self.result["peaks"] = peaks + self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion @@ -131,6 +133,8 @@ def compute_result(self, **result_params): self.result["motion"] = motion _run_key_saved = [ + ("peaks", "npy"), + ("peak_locations", "npy"), ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] @@ -161,7 +165,9 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift( + self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6) + ): import matplotlib.pyplot as plt if case_keys is None: @@ -195,6 +201,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + if raster: + peaks = bench.result["peaks"] + peak_locations = bench.result["peak_locations"] + rec = bench.recording + x = peaks["sample_index"] / rec.sampling_frequency + y = peak_locations[bench.direction] + ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index f9267c785a..3cf6dca04f 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -56,6 +56,15 @@ def create_benchmark(self, key): benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder) return benchmark + def remove_benchmark(self, key): + BenchmarkStudy.remove_benchmark(self, key) + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + import shutil + + if sorter_folder.exists(): + shutil.rmtree(sorter_folder) + def get_performance_by_unit(self, case_keys=None): import pandas as pd diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2af48407a3..9a0e242d62 100644 --- a/src/spikeinterface/core/basesorting.py +++ 
b/src/spikeinterface/core/basesorting.py @@ -135,7 +135,7 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, - unit_id, + unit_id: str | int, segment_index: Union[int, None] = None, start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0316b3bab1..aa69fe585b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,7 +2,7 @@ import math import warnings import numpy as np -from typing import Literal +from typing import Literal, Optional from math import ceil from .basesorting import SpikeVectorSortingSegment @@ -134,7 +134,7 @@ def generate_sorting( seed = _ensure_seed(seed) rng = np.random.default_rng(seed) num_segments = len(durations) - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] spikes = [] for segment_index in range(num_segments): @@ -1111,7 +1111,7 @@ def __init__( """ - unit_ids = np.arange(num_units) + unit_ids = [str(idx) for idx in np.arange(num_units)] super().__init__(sampling_frequency, unit_ids) self.num_units = num_units @@ -1138,6 +1138,7 @@ def __init__( firing_rates=firing_rates, refractory_period_seconds=self.refractory_period_seconds, seed=segment_seed, + unit_ids=unit_ids, t_start=None, ) self.add_sorting_segment(segment) @@ -1161,6 +1162,7 @@ def __init__( firing_rates: float | np.ndarray, refractory_period_seconds: float | np.ndarray, seed: int, + unit_ids: list[str], t_start: Optional[float] = None, ): self.num_units = num_units @@ -1177,7 +1179,8 @@ def __init__( self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64") self.segment_seed = seed - self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)} + self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids} + self.num_samples = math.ceil(sampling_frequency * duration) super().__init__(t_start) @@ -1280,7 +1283,7 @@ def __init__( noise_block_size: int = 30000, ): - channel_ids = np.arange(num_channels) + channel_ids = [str(idx) for idx in np.arange(num_channels)] dtype = np.dtype(dtype).name # Cast to string for serialization if dtype not in ("float32", "float64"): raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}") diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 55cbe6070a..fdad87287e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2092,6 +2092,13 @@ def load_data(self): import pandas as pd ext_data = pd.read_csv(ext_data_file, index_col=0) + # we need to cast the index to the unit id dtype (int or str) + unit_ids = self.sorting_analyzer.unit_ids + if ext_data.shape[0] == unit_ids.size: + # we force dtype to be the same as unit_ids + if ext_data.index.dtype != unit_ids.dtype: + ext_data.index = ext_data.index.astype(unit_ids.dtype) + elif ext_data_file.suffix == ".pkl": with ext_data_file.open("rb") as f: ext_data = pickle.load(f) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index b64f0610ea..3e3fcc7384 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -205,6 +205,7 @@ def to_sparse(self, sparsity): unit_ids=self.unit_ids, probe=self.probe, check_for_consistent_sparsity=self.check_for_consistent_sparsity, + is_scaled=self.is_scaled, ) def 
get_one_template_dense(self, unit_index): diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 64f7f76819..f243dd9d9f 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder): assert snippets.get_num_segments() == len(duration) assert snippets.get_num_channels() == num_channels - assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2]) - assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None)) + assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2]) + assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None)) # annotations / properties snippets.annotate(gre="ta") @@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder): ) # missing property - snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1]) + snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"]) values = snippets.get_property("string_property") assert values[2] == "" @@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder): snippets.set_property, key="string_property_nan", values=["hola", "chabon"], - ids=[0, 1], + ids=["0", "1"], missing_value=np.nan, ) # int properties without missing values raise an error assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2]) - snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200) + snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200) values = snippets.get_property("int_property") assert values.dtype.kind == "i" diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 118b6092a9..99d6890dfd 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -38,10 +38,12 @@ def test_channelsaggregationrecording(): assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg)) assert np.allclose( - traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg) + traces2_0, + recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg), ) assert np.allclose( - traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg) + traces3_2, + recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg), ) # all traces traces1 = recording1.get_traces(segment_index=seg) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 35ab18b5f2..15f089f784 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -31,6 +31,14 @@ def get_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), seed=2205, ) + + # TODO: the tests or the sorting analyzer make assumptions about the ids being integers + # So keeping this the way it was + integer_channel_ids = [int(id) for id in recording.get_channel_ids()] + integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] + + recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + 
sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 1e72b0ab28..3ecb702aa2 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -10,25 +10,29 @@ def test_basic_functions(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2]) - assert np.array_equal(sorting2.unit_ids, [0, 2]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"]) + assert np.array_equal(sorting2.unit_ids, ["0", "2"]) assert sorting2.get_parent() == sorting - sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"]) + sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"]) assert np.array_equal(sorting3.unit_ids, ["a", "b"]) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting2.get_unit_spike_train(unit_id="0", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0) + sorting.get_unit_spike_train(unit_id="0", segment_index=0), + sorting3.get_unit_spike_train(unit_id="a", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting2.get_unit_spike_train(unit_id="2", segment_index=0), ) assert np.array_equal( - sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0) + sorting.get_unit_spike_train(unit_id="2", segment_index=0), + sorting3.get_unit_spike_train(unit_id="b", segment_index=0), ) @@ -36,13 +40,13 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"]) + sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_custom_cache_spike_vector(): sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) - sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) assert np.all(cached_spike_vector == computed_spike_vector) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 0302ffe5b7..975f2fe22f 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -15,3 +15,7 @@ from .curation_format import validate_curation_dict, curation_label_to_dataframe, apply_curation from .sortingview_curation import apply_sortingview_curation + +# automated curation +from .model_based_curation import auto_label_units, load_model +from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git 
a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 5f85538b08..80f251ca43 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict):
if not removed_units_set.issubset(unit_set):
raise ValueError("Curation format: some removed units are not in the unit list")
+ for group in curation_dict["merge_unit_groups"]:
+ if len(group) < 2:
+ raise ValueError("Curation format: 'merge_unit_groups' must be a list of lists, each with at least 2 elements")
+
all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
for gp_1, gp_2 in combinations(all_merging_groups, 2):
if len(gp_1.intersection(gp_2)) != 0:
- raise ValueError("Some units belong to multiple merge groups")
+ raise ValueError("Curation format: some units belong to multiple merge groups")
if len(removed_units_set.intersection(merged_units_set)) != 0:
- raise ValueError("Some units were merged and deleted")
+ raise ValueError("Curation format: some units were merged and deleted")
# Check the labels exclusivity
for lbl in curation_dict["manual_labels"]:
@@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
for unit_ind, unit_id in enumerate(sorting.unit_ids):
if unit_id not in new_unit_ids:
- ind = curation_dict["unit_ids"].index(unit_id)
+ ind = list(curation_dict["unit_ids"]).index(unit_id)
all_values[unit_ind] = values[ind]
sorting.set_property(key, all_values)
@@ -253,7 +257,7 @@
group_values.append(value)
if len(set(group_values)) == 1:
# all group has the same label or empty
- sorting.set_property(key, values=group_values, ids=[new_unit_id])
+ sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
else:
for key in label_def["label_options"]:
@@ -339,18 +343,22 @@ def apply_curation(
elif isinstance(sorting_or_analyzer, SortingAnalyzer):
analyzer = sorting_or_analyzer
- analyzer = analyzer.remove_units(curation_dict["removed_units"])
- analyzer, new_unit_ids = analyzer.merge_units(
- curation_dict["merge_unit_groups"],
- censor_ms=censor_ms,
- merging_mode=merging_mode,
- sparsity_overlap=sparsity_overlap,
- new_id_strategy=new_id_strategy,
- return_new_unit_ids=True,
- format="memory",
- verbose=verbose,
- **job_kwargs,
- )
+ if len(curation_dict["removed_units"]) > 0:
+ analyzer = analyzer.remove_units(curation_dict["removed_units"])
+ if len(curation_dict["merge_unit_groups"]) > 0:
+ analyzer, new_unit_ids = analyzer.merge_units(
+ curation_dict["merge_unit_groups"],
+ censor_ms=censor_ms,
+ merging_mode=merging_mode,
+ sparsity_overlap=sparsity_overlap,
+ new_id_strategy=new_id_strategy,
+ return_new_unit_ids=True,
+ format="memory",
+ verbose=verbose,
+ **job_kwargs,
+ )
+ else:
+ new_unit_ids = []
apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict)
return analyzer
else:
diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
new file mode 100644
index 0000000000..93ad03734c
--- /dev/null
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -0,0 +1,435 @@
+import numpy as np
+from pathlib import Path
+import json
+import warnings
+import re
+
+from spikeinterface.core import SortingAnalyzer
+from spikeinterface.curation.train_manual_curation import (
+    try_to_get_metrics_from_analyzer,
_get_computed_metrics,
+    _format_metric_dataframe,
+)
+from copy import deepcopy
+
+
+class ModelBasedClassification:
+    """
+    Class for performing model-based classification on spike sorting data.
+
+    Parameters
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data.
+    pipeline : Pipeline
+        The pipeline object representing the trained classification model.
+
+    Attributes
+    ----------
+    sorting_analyzer : SortingAnalyzer
+        The sorting analyzer object containing the spike sorting data.
+    pipeline : Pipeline
+        The pipeline object representing the trained classification model.
+    required_metrics : Sequence[str]
+        The list of required metrics for classification, extracted from the pipeline.
+
+    Methods
+    -------
+    predict_labels()
+        Predicts the labels for the spike sorting data using the trained model.
+    """
+
+    def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline):
+        from sklearn.pipeline import Pipeline
+
+        if not isinstance(pipeline, Pipeline):
+            raise ValueError("The `pipeline` must be an instance of sklearn.pipeline.Pipeline")
+
+        self.sorting_analyzer = sorting_analyzer
+        self.pipeline = pipeline
+        self.required_metrics = pipeline.feature_names_in_
+
+    def predict_labels(
+        self, label_conversion=None, input_data=None, export_to_phy=False, model_info=None, enforce_metric_params=False
+    ):
+        """
+        Predicts the labels for the spike sorting data using the trained model.
+        Populates the sorting object with the predicted labels and probabilities as unit properties.
+
+        Parameters
+        ----------
+        model_info : dict or None, default: None
+            Model info, generated with the model, used to check the metric parameters used to train it.
+        label_conversion : dict or None, default: None
+            A dictionary for converting the predicted labels (which are integers) to custom labels. If None,
+            tries to find it in the `model_info` file. The dictionary should have the format {old_label: new_label}.
+        input_data : pandas.DataFrame or None, default: None
+            The input data for classification. If not provided, the method will extract metrics stored in the sorting analyzer.
+        export_to_phy : bool, default: False
+            Whether to export the classified units to Phy format.
+        enforce_metric_params : bool, default: False
+            If True and the parameters used to compute the metrics in `sorting_analyzer` are different from the parameters
+            used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
+
+        Returns
+        -------
+        pd.DataFrame
+            A dataframe containing the classified units and their corresponding predictions and probabilities,
+            indexed by their `unit_ids`.
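+
+        Examples
+        --------
+        A minimal sketch (assuming ``analyzer`` is a SortingAnalyzer with quality and
+        template metrics computed, and ``pipeline`` is a trained sklearn Pipeline):
+
+        >>> mbc = ModelBasedClassification(analyzer, pipeline)  # hypothetical inputs
+        >>> classified_units = mbc.predict_labels(label_conversion={0: "noise", 1: "good"})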
+ """ + import pandas as pd + + # Get metrics DataFrame for classification + if input_data is None: + input_data = _get_computed_metrics(self.sorting_analyzer) + else: + if not isinstance(input_data, pd.DataFrame): + raise ValueError("Input data must be a pandas DataFrame") + + input_data = self._check_required_metrics_are_present(input_data) + + if model_info is not None: + self._check_params_for_classification(enforce_metric_params, model_info=model_info) + + if model_info is not None and label_conversion is None: + try: + string_label_conversion = model_info["label_conversion"] + # json keys are strings; we convert these to ints + label_conversion = {} + for key, value in string_label_conversion.items(): + label_conversion[int(key)] = value + except: + warnings.warn("Could not find `label_conversion` key in `model_info.json` file") + + input_data = _format_metric_dataframe(input_data) + + # Apply classifier + predictions = self.pipeline.predict(input_data) + probabilities = self.pipeline.predict_proba(input_data) + probabilities = np.max(probabilities, axis=1) + + if isinstance(label_conversion, dict): + + if set(predictions).issubset(set(label_conversion.keys())) is False: + raise ValueError("Labels in predictions do not match those in label_conversion") + predictions = [label_conversion[label] for label in predictions] + + classified_units = pd.DataFrame( + zip(predictions, probabilities), columns=["prediction", "probability"], index=self.sorting_analyzer.unit_ids + ) + + # Set predictions and probability as sorting properties + self.sorting_analyzer.sorting.set_property("classifier_label", predictions) + self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities) + + if export_to_phy: + self._export_to_phy(classified_units) + + return classified_units + + def _check_required_metrics_are_present(self, calculated_metrics): + + # Check all the required metrics have been calculated + required_metrics = set(self.required_metrics) + if required_metrics.issubset(set(calculated_metrics)): + input_data = calculated_metrics[self.required_metrics] + else: + raise ValueError( + "Input data does not contain all required metrics for classification", + f"Missing metrics: {required_metrics.difference(calculated_metrics)}", + ) + + return input_data + + def _check_params_for_classification(self, enforce_metric_params=False, model_info=None): + """ + Check that quality and template metrics parameters match those used to train the model + + Parameters + ---------- + enforce_metric_params : bool, default: False + If True and the parameters used to compute the metrics in `sorting_analyzer` are different than the parmeters + used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised. + model_info : dict, default: None + Dictionary of model info containing provenance of the model. 
+ """ + + extension_names = ["quality_metrics", "template_metrics"] + + metric_extensions = [self.sorting_analyzer.get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + model_extension_params = model_info["metric_params"].get(extension_name + "_params") + + if metric_extension is not None and model_extension_params is not None: + + metric_params = metric_extension.params["metric_params"] + + inconsistent_metrics = [] + for metric in model_extension_params["metric_names"]: + model_metric_params = model_extension_params.get("metric_params") + if model_metric_params is None or metric not in model_metric_params: + inconsistent_metrics.append(metric) + else: + if metric_params[metric] != model_metric_params[metric]: + warning_message = f"{extension_name} params for {metric} do not match those used to train the model. Parameters can be found in the 'model_info.json' file." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + if len(inconsistent_metrics) > 0: + warning_message = f"Parameters used to compute metrics {inconsistent_metrics}, used to train this model, are unknown." + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + def _export_to_phy(self, classified_units): + """Export the classified units to Phy as cluster_prediction.tsv file""" + + import pandas as pd + + # Create a new DataFrame with unit_id, prediction, and probability columns from dict {unit_id: (prediction, probability)} + classified_df = pd.DataFrame.from_dict(classified_units, orient="index", columns=["prediction", "probability"]) + + # Export to Phy format + try: + sorting_path = self.sorting_analyzer.sorting.get_annotation("phy_folder") + assert sorting_path is not None + assert Path(sorting_path).is_dir() + except AssertionError: + raise ValueError("Phy folder not found in sorting annotations, or is not a directory") + + classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id") + + +def auto_label_units( + sorting_analyzer: SortingAnalyzer, + model_folder=None, + model_name=None, + repo_id=None, + label_conversion=None, + trust_model=False, + trusted=None, + export_to_phy=False, + enforce_metric_params=False, +): + """ + Automatically labels units based on a model-based classification, either from a model + hosted on HuggingFaceHub or one available in a local folder. + + This function returns the predicted labels and the prediction probabilities, and populates + the sorting object with the predicted labels and probabilities in the 'classifier_label' and + 'classifier_probability' properties. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object containing the spike sorting results. + model_folder : str or Path, defualt: None + The path to the folder containing the model + repo_id : str | Path, default: None + Hugging face repo id which contains the model e.g. 'username/model' + model_name: str | Path, default: None + Filename of model e.g. 'my_model.skops'. If None, uses first model found. + label_conversion : dic | None, default: None + A dictionary for converting the predicted labels (which are integers) to custom labels. If None, + tries to extract from `model_info.json` file. 
The dictionary should have the format {old_label: new_label}.
+    export_to_phy : bool, default: False
+        Whether to export the results to Phy format.
+    trust_model : bool, default: False
+        Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be
+        automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects.
+    trusted : list of str, default: None
+        Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file.
+    enforce_metric_params : bool, default: False
+        If True and the parameters used to compute the metrics in `sorting_analyzer` are different from the parameters
+        used to compute the metrics used to train the model, this function will raise an error. Otherwise, a warning is raised.
+
+
+    Returns
+    -------
+    classified_units : pd.DataFrame
+        A dataframe containing the classified units, indexed by the `unit_ids`, containing the predicted label
+        and confidence probability of each labelled unit.
+
+    Raises
+    ------
+    ValueError
+        If the pipeline is not an instance of sklearn.pipeline.Pipeline.
+
+    """
+    from sklearn.pipeline import Pipeline
+
+    model, model_info = load_model(
+        model_folder=model_folder, repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted
+    )
+
+    if not isinstance(model, Pipeline):
+        raise ValueError("The model must be an instance of sklearn.pipeline.Pipeline")
+
+    model_based_classification = ModelBasedClassification(sorting_analyzer, model)
+
+    classified_units = model_based_classification.predict_labels(
+        label_conversion=label_conversion,
+        export_to_phy=export_to_phy,
+        model_info=model_info,
+        enforce_metric_params=enforce_metric_params,
+    )
+
+    return classified_units
+
+
+def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
+    """
+    Loads a model and model_info from a HuggingFaceHub repo or a local folder.
+
+    Parameters
+    ----------
+    model_folder : str or Path, default: None
+        The path to the folder containing the model.
+    repo_id : str | Path, default: None
+        Hugging face repo id which contains the model e.g. 'username/model'
+    model_name : str | Path, default: None
+        Filename of the model e.g. 'my_model.skops'. If None, uses the first model found.
+    trust_model : bool, default: False
+        Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be
+        automatically inferred. If False, the `trusted` parameter must be provided to indicate the trusted objects.
+    trusted : list of str, default: None
+        Passed to skops.load. The object will be loaded only if there are only trusted objects and objects of types listed in trusted in the dumped file.
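+
+    For example, a minimal sketch loading from a hypothetical local folder:
+
+    >>> model, model_info = load_model(model_folder="my_model_folder", trust_model=True)  # hypothetical folder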
+
+
+    Returns
+    -------
+    model, model_info
+        A model and metadata about the model
+    """
+
+    if model_folder is None and repo_id is None:
+        raise ValueError("Please provide a 'model_folder' or a 'repo_id'.")
+    elif model_folder is not None and repo_id is not None:
+        raise ValueError("Please only provide one of 'model_folder' or 'repo_id'.")
+    elif model_folder is not None:
+        model, model_info = _load_model_from_folder(
+            model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted
+        )
+    else:
+        model, model_info = _load_model_from_huggingface(
+            repo_id=repo_id, model_name=model_name, trust_model=trust_model, trusted=trusted
+        )
+
+    return model, model_info
+
+
+def _load_model_from_huggingface(repo_id=None, model_name=None, trust_model=False, trusted=None):
+    """
+    Loads a model from a huggingface repo
+
+    Returns
+    -------
+    model, model_info
+        A model and metadata about the model
+    """
+
+    from huggingface_hub import list_repo_files
+    from huggingface_hub import hf_hub_download
+
+    # get repo filenames
+    repo_filenames = list_repo_files(repo_id=repo_id)
+
+    # download all skops and json files to temp directory
+    for filename in repo_filenames:
+        if Path(filename).suffix in [".skops", ".json"]:
+            full_path = hf_hub_download(repo_id=repo_id, filename=filename)
+            model_folder = Path(full_path).parent
+
+    model, model_info = _load_model_from_folder(
+        model_folder=model_folder, model_name=model_name, trust_model=trust_model, trusted=trusted
+    )
+
+    return model, model_info
+
+
+def _load_model_from_folder(model_folder=None, model_name=None, trust_model=False, trusted=None):
+    """
+    Loads a model and model_info from a folder
+
+    Returns
+    -------
+    model, model_info
+        A model and metadata about the model
+    """
+
+    import skops.io as skio
+    from skops.io.exceptions import UntrustedTypesFoundException
+
+    folder = Path(model_folder)
+    assert folder.is_dir(), f"The folder {folder} does not exist."
+
+    # look for any .skops files
+    skops_files = list(folder.glob("*.skops"))
+    assert len(skops_files) > 0, f"There are no '.skops' files in the folder {folder}"
+
+    if len(skops_files) > 1:
+        if model_name is None:
+            model_names = [f.name for f in skops_files]
+            raise ValueError(
+                f"There is more than one '.skops' file in folder {folder}. You have to specify "
+                f"the file using the 'model_name' argument. Available files:\n{model_names}"
+            )
+        else:
+            skops_file = folder / Path(model_name)
+            assert skops_file.is_file(), f"Model file {skops_file} not found."
+    elif len(skops_files) == 1:
+        skops_file = skops_files[0]
+
+    if trust_model and trusted is None:
+        try:
+            model = skio.load(skops_file)
+        except UntrustedTypesFoundException as e:
+            exception_msg = str(e)
+            # the exception message contains the list of untrusted objects. The following
+            # search assumes it is the only list in the message.
+            string_list = re.search(r"\[(.*?)\]", exception_msg).group()
+            trusted = [list_item for list_item in string_list.split("'") if len(list_item) > 2]
+
+    model = skio.load(skops_file, trusted=trusted)
+
+    model_info_path = folder / "model_info.json"
+    if not model_info_path.is_file():
+        warnings.warn("No 'model_info.json' file found in folder.
No metadata can be checked.") + model_info = None + else: + model_info = json.load(open(model_info_path)) + + model_info = handle_backwards_compatibility_metric_params(model_info) + + return model, model_info + + +def handle_backwards_compatibility_metric_params(model_info): + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("quality_metric_params") is not None + ): + if (qm_params := model_info["metric_params"]["quality_metric_params"].get("qm_params")) is not None: + model_info["metric_params"]["quality_metric_params"]["metric_params"] = qm_params + del model_info["metric_params"]["quality_metric_params"]["qm_params"] + + if ( + model_info.get("metric_params") is not None + and model_info.get("metric_params").get("template_metric_params") is not None + ): + if (tm_params := model_info["metric_params"]["template_metric_params"].get("metrics_kwargs")) is not None: + metric_params = {} + for metric_name in model_info["metric_params"]["template_metric_params"].get("metric_names"): + metric_params[metric_name] = deepcopy(tm_params) + model_info["metric_params"]["template_metric_params"]["metric_params"] = metric_params + del model_info["metric_params"]["template_metric_params"]["metrics_kwargs"] + + return model_info diff --git a/src/spikeinterface/curation/tests/common.py b/src/spikeinterface/curation/tests/common.py index 9cd20f4bfc..e9c4c4a463 100644 --- a/src/spikeinterface/curation/tests/common.py +++ b/src/spikeinterface/curation/tests/common.py @@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", **job_kwargs) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py new file mode 100644 index 0000000000..3683b417df --- /dev/null +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -0,0 +1,167 @@ +import pytest +from pathlib import Path +from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation +from spikeinterface.curation.model_based_curation import ModelBasedClassification +from spikeinterface.curation import auto_label_units, load_model +from spikeinterface.curation.train_manual_curation import _get_computed_metrics + +import numpy as np + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "curation" +else: + cache_folder = Path("cache_folder") / "curation" + + +@pytest.fixture +def model(): + """A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`. 
+    It has been trained locally and, when applied to `sorting_analyzer_for_curation`, will label its 5 units with
+    the following labels: [1,0,1,0,1]."""
+
+    model = load_model(Path(__file__).parent / "trained_pipeline/", trusted=["numpy.dtype"])
+    return model
+
+
+@pytest.fixture
+def required_metrics():
+    """These are the metrics on which `model` was trained."""
+    return ["num_spikes", "snr", "half_width"]
+
+
+def test_model_based_classification_init(sorting_analyzer_for_curation, model):
+    """Test that the ModelBasedClassification attributes are correctly initialised"""
+
+    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
+    assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation
+    assert model_based_classification.pipeline == model[0]
+    assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_)
+
+
+def test_metric_ordering_independence(sorting_analyzer_for_curation, model):
+    """The function `auto_label_units` needs the correct metrics to have been computed. However,
+    it should be independent of the order of computation. We test this here."""
+
+    sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
+
+    model_folder = Path(__file__).parent / Path("trained_pipeline")
+
+    prediction_prob_dataframe_1 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
+
+    prediction_prob_dataframe_2 = auto_label_units(
+        sorting_analyzer=sorting_analyzer_for_curation,
+        model_folder=model_folder,
+        trusted=["numpy.dtype"],
+    )
+
+    assert prediction_prob_dataframe_1.equals(prediction_prob_dataframe_2)
+
+
+def test_model_based_classification_get_metrics_for_classification(
+    sorting_analyzer_for_curation, model, required_metrics
+):
+    """If the user has not computed the required metrics, an error should be returned.
+    This test checks that an error occurs when the required metrics have not been computed,
+    and that no error is returned when the required metrics have been computed.
+ """ + + sorting_analyzer_for_curation.delete_extension("quality_metrics") + sorting_analyzer_for_curation.delete_extension("template_metrics") + + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + + # Check that ValueError is returned when no metrics are present in sorting_analyzer + with pytest.raises(ValueError): + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + + # Compute some (but not all) of the required metrics in sorting_analyzer, should still error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metrics[0]]) + computed_metrics = _get_computed_metrics(sorting_analyzer_for_curation) + with pytest.raises(ValueError): + model_based_classification._check_required_metrics_are_present(computed_metrics) + + # Compute all of the required metrics in sorting_analyzer, no more error + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metrics[0:2]) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metrics[2]]) + + metrics_data = _get_computed_metrics(sorting_analyzer_for_curation) + assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids()) + assert set(metrics_data.columns.to_list()) == set(required_metrics) + + +def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): + # Test the _export_to_phy() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = {0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)} + # Function should fail here + with pytest.raises(ValueError): + model_based_classification._export_to_phy(classified_units) + # Make temp output folder and set as phy_folder + phy_folder = cache_folder / "phy_folder" + phy_folder.mkdir(parents=True, exist_ok=True) + + model_based_classification.sorting_analyzer.sorting.annotate(phy_folder=phy_folder) + model_based_classification._export_to_phy(classified_units) + assert (phy_folder / "cluster_prediction.tsv").exists() + + +def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model): + """The model `model` has been trained on the `sorting_analyzer` used in this test with + the labels `[1, 0, 1, 0, 1]`. Hence if we apply the model to this `sorting_analyzer` + we expect these labels to be outputted. 
The test checks this, and also checks + that label conversion works as expected.""" + + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) + + # Test the predict_labels() method of ModelBasedClassification + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) + classified_units = model_based_classification.predict_labels() + predictions = classified_units["prediction"].values + + assert np.all(predictions == np.array([1, 0, 1, 0, 1])) + + conversion = {0: "noise", 1: "good"} + classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion) + predictions_labelled = classified_units_labelled["prediction"] + assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"]) + + +def test_exception_raised_when_metricparams_not_equal(sorting_analyzer_for_curation): + """We track whether the metric parameters used to compute the metrics used to train + a model are the same as the parameters used to compute the metrics in the sorting + analyzer which is being curated. If they are different, an error or warning will + be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here.""" + + sorting_analyzer_for_curation.compute( + "quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}} + ) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model_folder = Path(__file__).parent / Path("trained_pipeline") + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + + # an error should be raised if `enforce_metric_params` is True + with pytest.raises(Exception): + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) + + # but only a warning if `enforce_metric_params` is False + with pytest.warns(UserWarning): + model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info) + + # Now test the positive case. Recompute using the default parameters + sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"], metric_params={}) + sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) + + model, model_info = load_model(model_folder=model_folder, trusted=["numpy.dtype"]) + model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) + model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 945aca7937..ff80be365d 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -49,6 +49,9 @@ def test_gh_curation(): Test curation using GitHub URI. 
""" sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) + # curated link: # https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5 gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" @@ -76,6 +79,8 @@ def test_sha1_curation(): Test curation using SHA1 URI. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from SHA1 # curated link: @@ -105,6 +110,8 @@ def test_json_curation(): Test curation using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" @@ -248,6 +255,8 @@ def test_json_no_merge_curation(): Test curation with no merges using a JSON file. """ sorting = generate_sorting(num_units=10) + unit_ids_as_int = [id for id in range(sorting.get_num_units())] + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int) json_file = parent_folder / "sv-sorting-curation-no-merge.json" sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file) diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py new file mode 100644 index 0000000000..f455fbdb9c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py @@ -0,0 +1,285 @@ +import pytest +import numpy as np +import tempfile, csv +from pathlib import Path + +from spikeinterface.curation.tests.common import make_sorting_analyzer +from spikeinterface.curation.train_manual_curation import CurationModelTrainer, train_model + + +@pytest.fixture +def trainer(): + """A simple CurationModelTrainer object is created, which can later by used to + train models using data from `sorting_analyzer`s.""" + + folder = tempfile.mkdtemp() # Create a temporary output folder + imputation_strategies = ["median"] + scaling_techniques = ["standard_scaler"] + classifiers = ["LogisticRegression"] + metric_names = ["metric1", "metric2", "metric3"] + search_kwargs = {"cv": 3} + return CurationModelTrainer( + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + folder=folder, + metric_names=metric_names, + imputation_strategies=imputation_strategies, + scaling_techniques=scaling_techniques, + classifiers=classifiers, + search_kwargs=search_kwargs, + ) + + +def make_temp_training_csv(): + """Create a temporary CSV file with artificially generated quality metrics. + The data is designed to be easy to dicern between units. Even units metric + values are all `0`, while odd units metric values are all `1`. 
+ """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + writer = csv.writer(temp_file) + writer.writerow(["unit_id", "metric1", "metric2", "metric3"]) + for i in range(5): + writer.writerow([i * 2, 0, 0, 0]) + writer.writerow([i * 2 + 1, 1, 1, 1]) + return temp_file.name + + +def test_load_and_preprocess_full(trainer): + """Check that we load and preprocess the csv file from `make_temp_training_csv` + correctly.""" + temp_file_path = make_temp_training_csv() + + # Load and preprocess the data from the temporary CSV file + trainer.load_and_preprocess_csv([temp_file_path]) + + # Assert that the data is loaded and preprocessed correctly + for a, row in trainer.X.iterrows(): + assert np.all(row.values == [float(a % 2)] * 3) + for a, label in enumerate(trainer.y.values): + assert label == a % 2 + for a, row in trainer.testing_metrics.iterrows(): + assert np.all(row.values == [a % 2] * 3) + assert row.name == a + + +def test_apply_scaling_imputation(trainer): + """Take a simple training and test set and check that they are corrected scaled, + using a standard scaler which rescales the training distribution to have mean 0 + and variance 1. Length between each row is 3, so if x0 is the first value in the + column, all other values are scaled as x -> 2/3(x - x0) - 1. The y (labled) values + do not get scaled.""" + + from sklearn.impute._knn import KNNImputer + from sklearn.preprocessing._data import StandardScaler + + imputation_strategy = "knn" + scaling_technique = "standard_scaler" + X_train = np.array([[1, 2, 3], [4, 5, 6]]) + X_test = np.array([[7, 8, 9], [10, 11, 12]]) + y_train = np.array([0, 1]) + y_test = np.array([2, 3]) + + X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, imputer, scaler = trainer.apply_scaling_imputation( + imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test + ) + + first_row_elements = X_train[0] + for a, row in enumerate(X_train): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_train_scaled[a]) + for a, row in enumerate(X_test): + assert np.all(2 / 3 * (row - first_row_elements) - 1.0 == X_test_scaled[a]) + + assert np.all(y_train == y_train_scaled) + assert np.all(y_test == y_test_scaled) + + assert isinstance(imputer, KNNImputer) + assert isinstance(scaler, StandardScaler) + + +def test_get_classifier_search_space(trainer): + """For each classifier, there is a hyperparameter space we search over to find its + most accurate incarnation. 
Here, we check that we do indeed load the approprirate + dict of hyperparameter possibilities""" + + from sklearn.linear_model._logistic import LogisticRegression + + classifier = "LogisticRegression" + model, param_space = trainer.get_classifier_search_space(classifier) + + assert isinstance(model, LogisticRegression) + assert len(param_space) > 0 + assert isinstance(param_space, dict) + + +def test_get_custom_classifier_search_space(): + """Check that if a user passes a custom hyperparameter search space, that this is + passed correctly to the trainer.""" + + classifier = { + "LogisticRegression": { + "C": [0.1, 8.0], + "solver": ["lbfgs"], + "max_iter": [100, 400], + } + } + trainer = CurationModelTrainer(classifiers=classifier, labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]]) + + model, param_space = trainer.get_classifier_search_space(list(classifier.keys())[0]) + assert param_space == classifier["LogisticRegression"] + + +def test_saved_files(trainer): + """During the trainer's creation, the following files should be created: + - best_model.skops + - labels.csv + - model_accuracies.csv + - model_info.json + - training_data.csv + This test checks that these exist, and checks some properties of the files.""" + + import pandas as pd + import json + + trainer.X = np.random.rand(10, 3) + trainer.y = np.append(np.ones(5), np.zeros(5)) + + trainer.evaluate_model_config() + trainer_folder = Path(trainer.folder) + + assert trainer_folder.is_dir() + + best_model_path = trainer_folder / "best_model.skops" + model_accuracies_path = trainer_folder / "model_accuracies.csv" + training_data_path = trainer_folder / "training_data.csv" + labels_path = trainer_folder / "labels.csv" + model_info_path = trainer_folder / "model_info.json" + + assert (best_model_path).is_file() + + model_accuracies = pd.read_csv(model_accuracies_path) + model_accuracies["classifier name"].values[0] == "LogisticRegression" + assert len(model_accuracies) == 1 + + training_data = pd.read_csv(training_data_path) + assert np.all(np.isclose(training_data.values[:, 1:4], trainer.X, rtol=1e-10)) + + labels = pd.read_csv(labels_path) + assert np.all(labels.values[:, 1] == trainer.y.astype("float")) + + model_info = pd.read_json(model_info_path) + + with open(model_info_path) as f: + model_info = json.load(f) + + assert set(model_info.keys()) == set(["metric_params", "requirements", "label_conversion"]) + + +def test_train_model(): + """A simple function test to check that `train_model` doesn't fail with one csv inputs""" + + metrics_path = make_temp_training_csv() + folder = tempfile.mkdtemp() + metric_names = ["metric1", "metric2", "metric3"] + trainer = train_model( + mode="csv", + metrics_paths=[metrics_path], + folder=folder, + labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1]], + metric_names=metric_names, + imputation_strategies=["median"], + scaling_techniques=["standard_scaler"], + classifiers=["LogisticRegression"], + overwrite=True, + search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1}, + ) + assert isinstance(trainer, CurationModelTrainer) + + +def test_train_model_using_two_csvs(): + """Models can be trained using more than one set of training data. 
+    that `train_model` works with two inputs, from csv files."""
+
+    metrics_path_1 = make_temp_training_csv()
+    metrics_path_2 = make_temp_training_csv()
+
+    folder = tempfile.mkdtemp()
+    metric_names = ["metric1", "metric2", "metric3"]
+
+    trainer = train_model(
+        mode="csv",
+        metrics_paths=[metrics_path_1, metrics_path_2],
+        folder=folder,
+        labels=[[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]],
+        metric_names=metric_names,
+        imputation_strategies=["median"],
+        scaling_techniques=["standard_scaler"],
+        classifiers=["LogisticRegression"],
+        overwrite=True,
+    )
+    assert isinstance(trainer, CurationModelTrainer)
+
+
+def test_train_using_two_sorting_analyzers():
+    """Models can be trained using more than one set of training data. This test checks
+    that `train_model` works with two inputs, from sorting analyzers. It also checks that
+    an error is raised if the sorting_analyzers have different sets of metrics computed."""
+
+    sorting_analyzer_1 = make_sorting_analyzer()
+    sorting_analyzer_1.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}})
+
+    sorting_analyzer_2 = make_sorting_analyzer()
+    sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes", "snr"]}})
+
+    labels_1 = [0, 1, 1, 1, 1]
+    labels_2 = [1, 1, 0, 1, 1]
+
+    folder = tempfile.mkdtemp()
+    trainer = train_model(
+        analyzers=[sorting_analyzer_1, sorting_analyzer_2],
+        folder=folder,
+        labels=[labels_1, labels_2],
+        imputation_strategies=["median"],
+        scaling_techniques=["standard_scaler"],
+        classifiers=["LogisticRegression"],
+        overwrite=True,
+    )
+
+    assert isinstance(trainer, CurationModelTrainer)
+
+    # Check that there is an error raised if the metric names are different
+    sorting_analyzer_2 = make_sorting_analyzer()
+    sorting_analyzer_2.compute({"quality_metrics": {"metric_names": ["num_spikes"], "delete_existing_metrics": True}})
+
+    with pytest.raises(Exception):
+        trainer = train_model(
+            analyzers=[sorting_analyzer_1, sorting_analyzer_2],
+            folder=folder,
+            labels=[labels_1, labels_2],
+            imputation_strategies=["median"],
+            scaling_techniques=["standard_scaler"],
+            classifiers=["LogisticRegression"],
+            overwrite=True,
+        )
+
+    # Now check that there is an error raised if we demand the same metric params, but don't have them
+
+    sorting_analyzer_2.compute(
+        {
+            "quality_metrics": {
+                "metric_names": ["num_spikes", "snr"],
+                "metric_params": {"snr": {"peak_mode": "at_index"}},
+            }
+        }
+    )
+
+    with pytest.raises(Exception):
+        train_model(
+            analyzers=[sorting_analyzer_1, sorting_analyzer_2],
+            folder=folder,
+            labels=[labels_1, labels_2],
+            imputation_strategies=["median"],
+            scaling_techniques=["standard_scaler"],
+            classifiers=["LogisticRegression"],
+            search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 1},
+            overwrite=True,
+            enforce_metric_params=True,
+        )
diff --git a/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops
new file mode 100644
index 0000000000..362405f917
Binary files /dev/null and b/src/spikeinterface/curation/tests/trained_pipeline/best_model.skops differ
diff --git a/src/spikeinterface/curation/tests/trained_pipeline/labels.csv b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv
new file mode 100644
index 0000000000..46680a9e89
--- /dev/null
+++ b/src/spikeinterface/curation/tests/trained_pipeline/labels.csv
@@ -0,0 +1,21 @@
+unit_index,0
+0,1
+1,0
+2,1
+3,0
+4,1
+0,1
+1,0
+2,1
+3,0
+4,1
+0,1
+1,0
+2,1
+3,0
+4,1
+0,1
+1,0
+2,1 +3,0 +4,1 diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv new file mode 100644 index 0000000000..7f015c380b --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_accuracies.csv @@ -0,0 +1,2 @@ +,classifier name,imputation_strategy,scaling_strategy,accuracy,precision,recall,model_id,best_params +0,LogisticRegression,median,StandardScaler(),1.0000,1.0000,1.0000,0,"OrderedDict([('C', 4.811707275233983), ('max_iter', 384), ('solver', 'saga')])" diff --git a/src/spikeinterface/curation/tests/trained_pipeline/model_info.json b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json new file mode 100644 index 0000000000..75ced28486 --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/model_info.json @@ -0,0 +1,60 @@ +{ + "metric_params": { + "quality_metric_params": { + "metric_names": [ + "snr", + "num_spikes" + ], + "peak_sign": null, + "seed": null, + "metric_params": { + "num_spikes": {}, + "snr": { + "peak_sign": "neg", + "peak_mode": "extremum" + } + }, + "skip_pc_metrics": false, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "snr", + "num_spikes" + ] + }, + "template_metric_params": { + "metric_names": [ + "half_width" + ], + "sparsity": null, + "peak_sign": "neg", + "upsampling_factor": 10, + "metric_params": { + "half_width": { + "recovery_window_ms": 0.7, + "peak_relative_threshold": 0.2, + "peak_width_ms": 0.1, + "depth_direction": "y", + "min_channels_for_velocity": 5, + "min_r2_velocity": 0.5, + "exp_peak_function": "ptp", + "min_r2_exp_decay": 0.5, + "spread_threshold": 0.2, + "spread_smooth_um": 20, + "column_range": null + } + }, + "delete_existing_metrics": false, + "metrics_to_compute": [ + "half_width" + ] + } + }, + "requirements": { + "spikeinterface": "0.101.1", + "scikit-learn": "1.3.2" + }, + "label_conversion": { + "1": 1, + "0": 0 + } +} diff --git a/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv new file mode 100644 index 0000000000..c9efca17ad --- /dev/null +++ b/src/spikeinterface/curation/tests/trained_pipeline/training_data.csv @@ -0,0 +1,21 @@ +unit_id,snr,num_spikes,half_width +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 +0,21.026926,5968.0,0.00027333334 +1,34.64474,5928.0,0.00023666666 +2,6.986315,5954.0,0.00026666667 +3,8.223127,6032.0,0.00020333333 +4,2.7464194,6002.0,0.00026666667 diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py new file mode 100644 index 0000000000..7b315b0fba --- /dev/null +++ b/src/spikeinterface/curation/train_manual_curation.py @@ -0,0 +1,843 @@ +import os +import warnings +import numpy as np +import json +import spikeinterface +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.qualitymetrics import ( + get_quality_metric_list, + get_quality_pca_metric_list, + qm_compute_name_to_column_names, +) +from 
spikeinterface.postprocessing import get_template_metric_names
+from spikeinterface.postprocessing.template_metrics import tm_compute_name_to_column_names
+from pathlib import Path
+from copy import deepcopy
+
+
+def get_default_classifier_search_spaces():
+
+    from scipy.stats import uniform, randint
+
+    default_classifier_search_spaces = {
+        "RandomForestClassifier": {
+            "n_estimators": [100, 150],
+            "criterion": ["gini", "entropy"],
+            "min_samples_split": [2, 4],
+            "min_samples_leaf": [2, 4],
+            "class_weight": ["balanced", "balanced_subsample"],
+        },
+        "AdaBoostClassifier": {
+            "learning_rate": [1, 2],
+            "n_estimators": [50, 100],
+            "algorithm": ["SAMME", "SAMME.R"],
+        },
+        "GradientBoostingClassifier": {
+            "learning_rate": uniform(0.05, 0.1),
+            "n_estimators": randint(100, 150),
+            "max_depth": [2, 4],
+            "min_samples_split": [2, 4],
+            "min_samples_leaf": [2, 4],
+        },
+        "SVC": {
+            "C": uniform(0.001, 10.0),
+            "kernel": ["sigmoid", "rbf"],
+            "gamma": uniform(0.001, 10.0),
+            "probability": [True],
+        },
+        "LogisticRegression": {
+            "C": uniform(0.001, 10.0),
+            "solver": ["newton-cg", "lbfgs", "liblinear", "sag", "saga"],
+            "max_iter": [100],
+        },
+        "XGBClassifier": {
+            "max_depth": [2, 4],
+            "eta": uniform(0.2, 0.5),
+            "sampling_method": ["uniform"],
+            "grow_policy": ["depthwise", "lossguide"],
+        },
+        "CatBoostClassifier": {"depth": [2, 4], "learning_rate": uniform(0.05, 0.15), "n_estimators": [100, 150]},
+        "LGBMClassifier": {"learning_rate": uniform(0.05, 0.15), "n_estimators": randint(100, 150)},
+        "MLPClassifier": {
+            "activation": ["tanh", "relu"],
+            "solver": ["adam"],
+            "alpha": uniform(1e-7, 1e-1),
+            "learning_rate": ["constant", "adaptive"],
+            "n_iter_no_change": [32],
+        },
+    }
+
+    return default_classifier_search_spaces
+
+
+class CurationModelTrainer:
+    """
+    Used to train and evaluate machine learning models for spike sorting curation.
+
+    Parameters
+    ----------
+    labels : list of lists, default: None
+        List of curated labels for each unit; must be in the same order as the metrics data.
+    folder : str, default: None
+        The folder where outputs such as models and evaluation metrics will be saved, if specified. Requires the skops library. If None, output will not be saved on the file system.
+    metric_names : list of str, default: None
+        A list of metrics to use for training. If None, default metrics will be used.
+    imputation_strategies : list of str | None, default: None
+        A list of imputation strategies to try. Can be "knn", "iterative" or any allowed
+        strategy passable to the sklearn `SimpleImputer`. If None, the default strategies
+        `["median", "most_frequent", "knn", "iterative"]` will be used.
+    scaling_techniques : list of str | None, default: None
+        A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler",
+        or "robust_scaler". If None, all techniques will be used.
+    classifiers : list of str or dict, default: None
+        A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces.
+    test_size : float, default: 0.2
+        Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklearn`.
+    seed : int, default: None
+        Random seed for reproducibility. If None, a random seed will be generated.
+    smote : bool, default: False
+        Whether to apply SMOTE for class imbalance. Requires the imbalanced-learn package.
+    verbose : bool, default: True
+        If True, useful information is printed during training.
+    search_kwargs : dict or None, default: None
+        Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use
+        `search_kwargs = {'cv': 3, 'scoring': 'balanced_accuracy', 'n_iter': 25}`.
+
+    Attributes
+    ----------
+    folder : str
+        The folder where outputs such as models and evaluation metrics will be saved. Requires the skops library.
+    labels : list of lists, default: None
+        List of curated labels for each `sorting_analyzer` and each unit; must be in the same order as the metrics data.
+    imputation_strategies : list of str | None, default: None
+        A list of imputation strategies to try. Can be "knn", "iterative" or any allowed
+        strategy passable to the sklearn `SimpleImputer`. If None, the default strategies
+        `["median", "most_frequent", "knn", "iterative"]` will be used.
+    scaling_techniques : list of str | None, default: None
+        A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler",
+        or "robust_scaler". If None, all techniques will be used.
+    classifiers : list of str
+        The list of classifiers to evaluate.
+    classifier_search_space : dict or None
+        Dictionary of classifiers and their hyperparameter search spaces, if provided. If None, default search spaces are used.
+    seed : int
+        Random seed for reproducibility.
+    metrics_list : list of str
+        The list of metrics to use for training.
+    X : pandas.DataFrame or None
+        The feature matrix after preprocessing.
+    y : pandas.Series or None
+        The target vector after preprocessing.
+    testing_metrics : dict or None
+        Dictionary to hold testing metrics data.
+    label_conversion : dict or None
+        Dictionary to map string labels to integer codes if the target column contains string labels.
+
+    Methods
+    -------
+    get_default_metrics_list()
+        Returns the default list of metrics.
+    load_and_preprocess_full(path)
+        Loads and preprocesses the data from the given path.
+    load_data_file(path)
+        Loads the data file from the given path.
+    process_test_data_for_classification()
+        Processes the test data for classification.
+    apply_scaling_imputation(imputation_strategy, scaling_technique, X_train, X_val, y_train, y_val)
+        Applies the specified imputation and scaling techniques to the data.
+    get_classifier_instance(classifier_name)
+        Returns an instance of the specified classifier.
+    get_classifier_search_space(classifier_name)
+        Returns the search space for hyperparameter tuning for the specified classifier.
+    get_default_classifier_search_spaces()
+        Returns the default search spaces for hyperparameter tuning for the classifiers.
+    evaluate_model_config()
+        Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers.
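+
+    Examples
+    --------
+    A minimal sketch (``"metrics.csv"`` is a hypothetical file of computed metrics and
+    ``labels`` a list of curated labels in the same order as its rows):
+
+    >>> trainer = CurationModelTrainer(labels=[labels], folder="model_folder")
+    >>> trainer.load_and_preprocess_csv(["metrics.csv"])
+    >>> trainer.evaluate_model_config()
+    >>> best_model = trainer.best_pipeline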
+ """ + + def __init__( + self, + labels=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + seed=None, + smote=False, + verbose=True, + search_kwargs=None, + **job_kwargs, + ): + + import pandas as pd + + if imputation_strategies is None: + imputation_strategies = ["median", "most_frequent", "knn", "iterative"] + + if scaling_techniques is None: + scaling_techniques = [ + "standard_scaler", + "min_max_scaler", + "robust_scaler", + ] + + if classifiers is None: + self.classifiers = ["RandomForestClassifier"] + self.classifier_search_space = None + elif isinstance(classifiers, dict): + self.classifiers = list(classifiers.keys()) + self.classifier_search_space = classifiers + elif isinstance(classifiers, list): + self.classifiers = classifiers + self.classifier_search_space = None + else: + raise ValueError("classifiers must be a list or dictionary") + + # check if labels is a list of lists + if not all(isinstance(label, list) or isinstance(label, np.ndarray) for label in labels): + raise ValueError("labels must be a list of lists") + + self.folder = Path(folder) if folder is not None else None + self.imputation_strategies = imputation_strategies + self.scaling_techniques = scaling_techniques + self.test_size = test_size + self.seed = seed if seed is not None else np.random.default_rng(seed=None).integers(0, 2**31) + self.metrics_params = {} + self.smote = smote + self.label_conversion = None + self.verbose = verbose + self.search_kwargs = search_kwargs + + self.X = None + self.testing_metrics = None + + self.requirements = {"spikeinterface": spikeinterface.__version__} + + self.y = pd.concat([pd.DataFrame(one_labels)[0] for one_labels in labels]) + + self.metric_names = metric_names + + if self.folder is not None and not self.folder.is_dir(): + self.folder.mkdir(parents=True, exist_ok=True) + + # update job_kwargs with global ones + job_kwargs = fix_job_kwargs(job_kwargs) + self.n_jobs = job_kwargs["n_jobs"] + + def get_default_metrics_list(self): + """Returns the default list of metrics.""" + return get_quality_metric_list() + get_quality_pca_metric_list() + get_template_metric_names() + + def load_and_preprocess_analyzers(self, analyzers, enforce_metric_params): + """ + Loads and preprocesses the quality metrics and labels from the given list of SortingAnalyzer objects. 
+ """ + import pandas as pd + + metrics_for_each_analyzer = [_get_computed_metrics(an) for an in analyzers] + check_metric_names_are_the_same(metrics_for_each_analyzer) + + self.testing_metrics = pd.concat(metrics_for_each_analyzer, axis=0) + + # Set metric names to those calculated if not provided + if self.metric_names is None: + warnings.warn("No metric_names provided, using all metrics calculated by the analyzers") + self.metric_names = self.testing_metrics.columns.tolist() + + conflicting_metrics = self._check_metrics_parameters(analyzers, enforce_metric_params) + + self.metrics_params = {} + + extension_names = ["quality_metrics", "template_metrics"] + metric_extensions = [analyzers[0].get_extension(extension_name) for extension_name in extension_names] + + for metric_extension, extension_name in zip(metric_extensions, extension_names): + + # remove the 's' at the end of the extension name + extension_name = extension_name[:-1] + if metric_extension is not None: + self.metrics_params[extension_name + "_params"] = metric_extension.params + + # Only save metric params which are 1) consistent and 2) exist in metric_names + metric_names = metric_extension.params["metric_names"] + consistent_metrics = list(set(metric_names).difference(set(conflicting_metrics))) + consistent_metric_params = { + metric: metric_extension.params["metric_params"][metric] for metric in consistent_metrics + } + self.metrics_params[extension_name + "_params"]["metric_params"] = consistent_metric_params + + self.process_test_data_for_classification() + + def _check_metrics_parameters(self, analyzers, enforce_metric_params): + """Checks that the metrics of each analyzer have been calcualted using the same parameters""" + + extension_names = ["quality_metrics", "template_metrics"] + + conflicting_metrics = [] + for analyzer_index_1, analyzer_1 in enumerate(analyzers): + for analyzer_index_2, analyzer_2 in enumerate(analyzers): + + if analyzer_index_1 <= analyzer_index_2: + continue + else: + + metric_params_1 = {} + metric_params_2 = {} + + for extension_name in extension_names: + if (extension_1 := analyzer_1.get_extension(extension_name)) is not None: + metric_params_1.update(extension_1.params["metric_params"]) + if (extension_2 := analyzer_2.get_extension(extension_name)) is not None: + metric_params_2.update(extension_2.params["metric_params"]) + + conflicting_metrics_between_1_2 = [] + # check quality metrics params + for metric, params_1 in metric_params_1.items(): + if params_1 != metric_params_2.get(metric): + conflicting_metrics_between_1_2.append(metric) + + conflicting_metrics += conflicting_metrics_between_1_2 + + if len(conflicting_metrics_between_1_2) > 0: + warning_message = f"Parameters used to calculate {conflicting_metrics_between_1_2} are different for sorting_analyzers #{analyzer_index_1} and #{analyzer_index_2}" + if enforce_metric_params is True: + raise Exception(warning_message) + else: + warnings.warn(warning_message) + + unique_conflicting_metrics = set(conflicting_metrics) + return unique_conflicting_metrics + + def load_and_preprocess_csv(self, paths): + self._load_data_files(paths) + self.process_test_data_for_classification() + self.get_metric_params_csv() + + def get_metric_params_csv(self): + + from itertools import chain + + qm_metric_names = list(chain.from_iterable(qm_compute_name_to_column_names.values())) + tm_metric_names = list(chain.from_iterable(tm_compute_name_to_column_names.values())) + + quality_metric_names = [] + template_metric_names = [] + + for metric_name in 
self.metric_names: + if metric_name in qm_metric_names: + quality_metric_names.append(metric_name) + if metric_name in tm_metric_names: + template_metric_names.append(metric_name) + + self.metrics_params = {} + if quality_metric_names != {}: + self.metrics_params["quality_metric_params"] = {"metric_names": quality_metric_names} + if template_metric_names != {}: + self.metrics_params["template_metric_params"] = {"metric_names": template_metric_names} + + return + + def process_test_data_for_classification(self): + """ + Cleans the input data so that it can be used by sklearn. + + Extracts the target variable and features from the loaded dataset. + It handles string labels by converting them to integer codes and reindexes the + feature matrix to match the specified metrics list. Infinite values in the features + are replaced with NaN, and any remaining NaN values are filled with zeros. + + Raises + ------ + ValueError + If the target column specified is not found in the loaded dataset. + + Notes + ----- + If the target column contains string labels, a warning is issued and the labels + are converted to integer codes. The mapping from string labels to integer codes + is stored in the `label_conversion` attribute. + """ + + # Convert string labels to integer codes to allow classification + new_y = self.y.astype("category").cat.codes + self.label_conversion = dict(zip(new_y, self.y)) + self.y = new_y + + # Extract features + try: + if (set(self.metric_names) - set(self.testing_metrics.columns) != set()) and self.verbose is True: + print( + f"Dropped metrics (calculated but not included in metric_names): {set(self.testing_metrics.columns) - set(self.metric_names)}" + ) + self.X = self.testing_metrics[self.metric_names] + except KeyError as e: + raise KeyError(f"{str(e)}, metrics_list contains invalid metric names") + + self.X = self.testing_metrics.reindex(columns=self.metric_names) + self.X = _format_metric_dataframe(self.X) + + def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test): + """Impute and scale the data using the specified techniques.""" + from sklearn.experimental import enable_iterative_imputer + from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer + from sklearn.ensemble import HistGradientBoostingRegressor + from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler + + if imputation_strategy == "knn": + imputer = KNNImputer(n_neighbors=5) + elif imputation_strategy == "iterative": + imputer = IterativeImputer( + estimator=HistGradientBoostingRegressor(random_state=self.seed), random_state=self.seed + ) + else: + imputer = SimpleImputer(strategy=imputation_strategy) + + if scaling_technique == "standard_scaler": + scaler = StandardScaler() + elif scaling_technique == "min_max_scaler": + scaler = MinMaxScaler() + elif scaling_technique == "robust_scaler": + scaler = RobustScaler() + else: + raise ValueError( + f"Unknown scaling technique: {scaling_technique}. Supported scaling techniques are 'standard_scaler', 'min_max_scaler' and 'robust_scaler." 
+            )
+
+        y_train_processed = y_train.astype(int)
+        y_test = y_test.astype(int)
+
+        X_train_imputed = imputer.fit_transform(X_train)
+        X_test_imputed = imputer.transform(X_test)
+        X_train_processed = scaler.fit_transform(X_train_imputed)
+        X_test_processed = scaler.transform(X_test_imputed)
+
+        # Apply SMOTE for class imbalance
+        if self.smote:
+            try:
+                from imblearn.over_sampling import SMOTE
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError("Please install imbalanced-learn package to use SMOTE")
+            smote = SMOTE(random_state=self.seed)
+            X_train_processed, y_train_processed = smote.fit_resample(X_train_processed, y_train_processed)
+
+        return X_train_processed, X_test_processed, y_train_processed, y_test, imputer, scaler
+
+    def get_classifier_instance(self, classifier_name):
+        from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
+        from sklearn.svm import SVC
+        from sklearn.linear_model import LogisticRegression
+        from sklearn.neural_network import MLPClassifier
+
+        classifier_mapping = {
+            "RandomForestClassifier": RandomForestClassifier(random_state=self.seed),
+            "AdaBoostClassifier": AdaBoostClassifier(random_state=self.seed),
+            "GradientBoostingClassifier": GradientBoostingClassifier(random_state=self.seed),
+            "SVC": SVC(random_state=self.seed),
+            "LogisticRegression": LogisticRegression(random_state=self.seed),
+            "MLPClassifier": MLPClassifier(random_state=self.seed),
+        }
+
+        # Check lightgbm package install
+        if classifier_name == "LGBMClassifier":
+            try:
+                import lightgbm
+
+                self.requirements["lightgbm"] = lightgbm.__version__
+                classifier_mapping["LGBMClassifier"] = lightgbm.LGBMClassifier(random_state=self.seed, verbose=-1)
+            except ImportError:
+                raise ImportError("Please install lightgbm package to use LGBMClassifier")
+        elif classifier_name == "CatBoostClassifier":
+            try:
+                import catboost
+
+                self.requirements["catboost"] = catboost.__version__
+                classifier_mapping["CatBoostClassifier"] = catboost.CatBoostClassifier(
+                    silent=True, random_state=self.seed
+                )
+            except ImportError:
+                raise ImportError("Please install catboost package to use CatBoostClassifier")
+        elif classifier_name == "XGBClassifier":
+            try:
+                import xgboost
+
+                self.requirements["xgboost"] = xgboost.__version__
+                classifier_mapping["XGBClassifier"] = xgboost.XGBClassifier(
+                    use_label_encoder=False, random_state=self.seed
+                )
+            except ImportError:
+                raise ImportError("Please install xgboost package to use XGBClassifier")
+
+        if classifier_name not in classifier_mapping:
+            raise ValueError(
+                f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())"
+            )
+
+        return classifier_mapping[classifier_name]
+
+    def get_classifier_search_space(self, classifier_name):
+
+        default_classifier_search_spaces = get_default_classifier_search_spaces()
+
+        if classifier_name not in default_classifier_search_spaces:
+            raise ValueError(
+                f"Unknown classifier: {classifier_name}. To see list of supported classifiers run\n\t>>> from spikeinterface.curation import get_default_classifier_search_spaces\n\t>>> print(get_default_classifier_search_spaces().keys())"
+            )
+
+        model = self.get_classifier_instance(classifier_name)
+        if self.classifier_search_space is not None:
+            param_space = self.classifier_search_space[classifier_name]
+        else:
+            param_space = default_classifier_search_spaces[classifier_name]
+        return model, param_space
+
+    def evaluate_model_config(self):
+        """
+        Evaluates the model configurations with the given imputation strategies, scaling techniques, and classifiers.
+
+        This method splits the preprocessed data into training and testing sets, then evaluates the specified
+        combinations of imputation strategies, scaling techniques, and classifiers. The evaluation results are
+        saved to the output folder.
+
+        Raises
+        ------
+        ValueError
+            If any of the specified classifier names are not recognized.
+
+        Notes
+        -----
+        The method converts the classifier names to actual classifier instances before evaluating them.
+        The evaluation results, including the best model and its parameters, are saved to the output folder.
+        """
+        from sklearn.model_selection import train_test_split
+
+        X_train, X_test, y_train, y_test = train_test_split(
+            self.X, self.y, test_size=self.test_size, random_state=self.seed, stratify=self.y
+        )
+        classifier_instances = [self.get_classifier_instance(clf) for clf in self.classifiers]
+        self._evaluate(
+            self.imputation_strategies,
+            self.scaling_techniques,
+            classifier_instances,
+            X_train,
+            X_test,
+            y_train,
+            y_test,
+            self.search_kwargs,
+        )
+
+    def _load_data_files(self, paths):
+        import pandas as pd
+
+        self.testing_metrics = pd.concat([pd.read_csv(path, index_col=0) for path in paths], axis=0)
+
+    def _evaluate(
+        self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test, search_kwargs
+    ):
+        from joblib import Parallel, delayed
+        from sklearn.pipeline import Pipeline
+        import pandas as pd
+
+        results = Parallel(n_jobs=self.n_jobs)(
+            delayed(self._train_and_evaluate)(
+                imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, idx, search_kwargs
+            )
+            for idx, (imputation_strategy, scaler, classifier) in enumerate(
+                (imputation_strategy, scaler, classifier)
+                for imputation_strategy in imputation_strategies
+                for scaler in scaling_techniques
+                for classifier in classifiers
+            )
+        )
+
+        test_accuracies, models = zip(*results)
+
+        # sort by the user-supplied scoring method if one was given, otherwise by balanced accuracy
+        if self.search_kwargs is None or self.search_kwargs.get("scoring") is None:
+            scoring_method = "balanced_accuracy"
+        else:
+            scoring_method = self.search_kwargs.get("scoring")
+
+        self.test_accuracies_df = pd.DataFrame(test_accuracies).sort_values(scoring_method, ascending=False)
+
+        best_model_id = int(self.test_accuracies_df.iloc[0]["model_id"])
+        best_model, best_imputer, best_scaler = models[best_model_id]
+
+        best_pipeline = Pipeline(
+            [("imputer", best_imputer), ("scaler", best_scaler), ("classifier", best_model.best_estimator_)]
+        )
+
+        self.best_pipeline = best_pipeline
+
+        if self.folder is not None:
+            self._save()
+
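+    # `_save` persists everything needed to reload and audit the model:
+    #   best_model.skops              - the winning sklearn Pipeline, in skops format
+    #   model_accuracies.csv          - scores for every evaluated configuration
+    #   model_info.json               - metric params, package requirements and label mapping
+    #   training_data.csv, labels.csv - the data the model was trained on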
/ f"best_model.skops") + self.test_accuracies_df.to_csv(self.folder / f"model_accuracies.csv", float_format="%.4f") + + model_info = {} + model_info["metric_params"] = self.metrics_params + + model_info["requirements"] = self.requirements + + model_info["label_conversion"] = self.label_conversion + + param_file = self.folder / "model_info.json" + Path(param_file).write_text(json.dumps(model_info, indent=4), encoding="utf8") + + def _train_and_evaluate( + self, imputation_strategy, scaler, classifier, X_train, X_test, y_train, y_test, model_id, search_kwargs + ): + from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score + + search_kwargs = set_default_search_kwargs(search_kwargs) + + X_train_scaled, X_test_scaled, y_train, y_test, imputer, scaler = self.apply_scaling_imputation( + imputation_strategy, scaler, X_train, X_test, y_train, y_test + ) + if self.verbose is True: + print(f"Running {classifier.__class__.__name__} with imputation {imputation_strategy} and scaling {scaler}") + model, param_space = self.get_classifier_search_space(classifier.__class__.__name__) + + try: + from skopt import BayesSearchCV + + model = BayesSearchCV( + model, + param_space, + random_state=self.seed, + **search_kwargs, + ) + except: + if self.verbose is True: + print("BayesSearchCV from scikit-optimize not available, using RandomizedSearchCV") + from sklearn.model_selection import RandomizedSearchCV + + model = RandomizedSearchCV(model, param_space, **search_kwargs) + + model.fit(X_train_scaled, y_train) + y_pred = model.predict(X_test_scaled) + balanced_acc = balanced_accuracy_score(y_test, y_pred) + precision = precision_score(y_test, y_pred, average="macro") + recall = recall_score(y_test, y_pred, average="macro") + return { + "classifier name": classifier.__class__.__name__, + "imputation_strategy": imputation_strategy, + "scaling_strategy": scaler, + "balanced_accuracy": balanced_acc, + "precision": precision, + "recall": recall, + "model_id": model_id, + "best_params": model.best_params_, + }, (model, imputer, scaler) + + +def train_model( + mode="analyzers", + labels=None, + analyzers=None, + metrics_paths=None, + folder=None, + metric_names=None, + imputation_strategies=None, + scaling_techniques=None, + classifiers=None, + test_size=0.2, + overwrite=False, + seed=None, + search_kwargs=None, + verbose=True, + enforce_metric_params=False, + **job_kwargs, +): + """ + Trains and evaluates machine learning models for spike sorting curation. + + This function initializes a `CurationModelTrainer` object, loads and preprocesses the data, + and evaluates the specified combinations of imputation strategies, scaling techniques, and classifiers. + The evaluation results, including the best model and its parameters, are saved to the output folder. + + Parameters + ---------- + mode : "analyzers" | "csv", default: "analyzers" + Mode to use for training. + analyzers : list of SortingAnalyzer | None, default: None + List of SortingAnalyzer objects containing the quality metrics and labels to use for training, if using 'analyzers' mode. + labels : list of list | None, default: None + List of curated labels for each unit; must be in the same order as the metrics data. + metrics_paths : list of str or None, default: None + List of paths to the CSV files containing the metrics data if using 'csv' mode. + folder : str | None, default: None + The folder where outputs such as models and evaluation metrics will be saved. 
+    metric_names : list of str | None, default: None
+        A list of metrics to use for training. If None, default metrics will be used.
+    imputation_strategies : list of str | None, default: None
+        A list of imputation strategies to try. Can be "knn", "iterative" or any allowed
+        strategy passable to the sklearn `SimpleImputer`. If None, the default strategies
+        `["median", "most_frequent", "knn", "iterative"]` will be used.
+    scaling_techniques : list of str | None, default: None
+        A list of scaling techniques to try. Can be "standard_scaler", "min_max_scaler",
+        or "robust_scaler". If None, all techniques will be used.
+    classifiers : list of str | dict | None, default: None
+        A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_classifier_search_space` method for the default search spaces & format for custom spaces.
+    test_size : float, default: 0.2
+        Proportion of the dataset to include in the test split, passed to `train_test_split` from `sklearn`.
+    overwrite : bool, default: False
+        Overwrites the `folder` if it already exists
+    seed : int | None, default: None
+        Random seed for reproducibility. If None, a random seed will be generated.
+    search_kwargs : dict or None, default: None
+        Keyword arguments passed to `BayesSearchCV` or `RandomizedSearchCV` from `sklearn`. If None, use
+        `search_kwargs = {'cv': 5, 'scoring': 'balanced_accuracy', 'n_iter': 25}`.
+    verbose : bool, default: True
+        If True, useful information is printed during training.
+    enforce_metric_params : bool, default: False
+        If True and the parameters used to calculate the metrics of the different `sorting_analyzer`s
+        differ, an error will be raised.
+
+
+    Returns
+    -------
+    CurationModelTrainer
+        The `CurationModelTrainer` object used for training and evaluation.
+
+    Notes
+    -----
+    This function handles the entire workflow of initializing the trainer, loading and preprocessing the data,
+    and evaluating the models. The evaluation results are saved to the specified output folder.
+    """
+
+    if folder is None:
+        raise Exception("You must supply a folder for the model to be saved in using `folder='path/to/folder/'`")
+
+    if overwrite is False:
+        assert not Path(folder).is_dir(), f"folder {folder} already exists, choose another name or use overwrite=True"
+
+    if labels is None:
+        raise Exception("You must supply a list of lists of curated labels using `labels = [[...],[...],...]`")
+
+    if mode not in ["analyzers", "csv"]:
+        raise Exception("`mode` must be equal to 'analyzers' or 'csv'.")
+
+    if (test_size > 1.0) or (test_size < 0.0):
+        raise Exception("`test_size` must be between 0.0 and 1.0")
+
+    trainer = CurationModelTrainer(
+        labels=labels,
+        folder=folder,
+        metric_names=metric_names,
+        imputation_strategies=imputation_strategies,
+        scaling_techniques=scaling_techniques,
+        classifiers=classifiers,
+        test_size=test_size,
+        seed=seed,
+        verbose=verbose,
+        search_kwargs=search_kwargs,
+        **job_kwargs,
+    )
+
+    if mode == "analyzers":
+        assert analyzers is not None, "Analyzers must be provided as a list for mode 'analyzers'"
+        trainer.load_and_preprocess_analyzers(analyzers, enforce_metric_params)
+
+    elif mode == "csv":
+        for metrics_path in metrics_paths:
+            assert Path(metrics_path).is_file(), f"{metrics_path} is not a file."
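+        # Each CSV is read with `pd.read_csv(path, index_col=0)` (see `_load_data_files`)
+        # and the frames are concatenated row-wise, so the files should contain one row
+        # per unit and share the same metric columns.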
+        trainer.load_and_preprocess_csv(metrics_paths)
+
+    trainer.evaluate_model_config()
+    return trainer
+
+
+def _get_computed_metrics(sorting_analyzer):
+    """Loads and organises the computed metrics from a sorting_analyzer into a single dataframe"""
+
+    import pandas as pd
+
+    quality_metrics, template_metrics = try_to_get_metrics_from_analyzer(sorting_analyzer)
+    calculated_metrics = pd.concat([quality_metrics, template_metrics], axis=1)
+
+    # Remove any metrics for non-existent units, raise error if no units are present
+    calculated_metrics = calculated_metrics.loc[calculated_metrics.index.isin(sorting_analyzer.sorting.get_unit_ids())]
+    if calculated_metrics.shape[0] == 0:
+        raise ValueError("No units present in sorting data")
+
+    return calculated_metrics
+
+
+def try_to_get_metrics_from_analyzer(sorting_analyzer):
+
+    extension_names = ["quality_metrics", "template_metrics"]
+    metric_extensions = [sorting_analyzer.get_extension(extension_name) for extension_name in extension_names]
+
+    if not any(metric_extensions):
+        raise ValueError(
+            "At least one of quality metrics or template metrics must be computed before classification.",
+            "Compute both using `sorting_analyzer.compute('quality_metrics', 'template_metrics')`",
+        )
+
+    metric_extensions_data = []
+    for metric_extension in metric_extensions:
+        try:
+            metric_extensions_data.append(metric_extension.get_data())
+        except Exception:
+            metric_extensions_data.append(None)
+
+    return metric_extensions_data
+
+
+def set_default_search_kwargs(search_kwargs):
+
+    if search_kwargs is None:
+        search_kwargs = {}
+
+    if search_kwargs.get("cv") is None:
+        search_kwargs["cv"] = 5
+    if search_kwargs.get("scoring") is None:
+        search_kwargs["scoring"] = "balanced_accuracy"
+    if search_kwargs.get("n_iter") is None:
+        search_kwargs["n_iter"] = 25
+
+    return search_kwargs
+
+
+def check_metric_names_are_the_same(metrics_for_each_analyzer):
+    """
+    Given a list of dataframes, checks that the keys are all equal.
+    """
+
+    for i, metrics_for_analyzer_1 in enumerate(metrics_for_each_analyzer):
+        for j, metrics_for_analyzer_2 in enumerate(metrics_for_each_analyzer):
+            if i > j:
+                metric_names_1 = set(metrics_for_analyzer_1.keys())
+                metric_names_2 = set(metrics_for_analyzer_2.keys())
+                if metric_names_1 != metric_names_2:
+                    metrics_in_1_but_not_2 = metric_names_1.difference(metric_names_2)
+                    metrics_in_2_but_not_1 = metric_names_2.difference(metric_names_1)
+
+                    error_message = f"Computed metrics are not equal for sorting_analyzers #{j} and #{i}\n"
+                    if metrics_in_1_but_not_2:
+                        error_message += f"#{j} does not contain {metrics_in_1_but_not_2}, which #{i} does."
+                    if metrics_in_2_but_not_1:
+                        error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does."
+ raise Exception(error_message) + + +def _format_metric_dataframe(input_data): + + input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x) + input_data = input_data.astype("float32") + + return input_data diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 24bc7591e4..dd24e6cae7 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -250,6 +250,7 @@ def __init__( except: warnings.warn(f"Could not load synchronized timestamps for {stream_name}") + self.annotate(experiment_name=f"experiment{exp_id}") self._stream_folders = stream_folders self._kwargs.update( diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d797e64910..171992f6b1 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -599,6 +599,8 @@ def __init__( else: gains, offsets, locations, groups = self._fetch_main_properties_backend() self.extra_requirements.append("h5py") + if stream_mode is not None: + self.extra_requirements.append(stream_mode) self.set_channel_gains(gains) self.set_channel_offsets(offsets) if locations is not None: @@ -1100,6 +1102,8 @@ def __init__( for property_name, property_values in properties.items(): values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] self.set_property(property_name, values) + if stream_mode is not None: + self.extra_requirements.append(stream_mode) if stream_mode is None and file_path is not None: file_path = str(Path(file_path).resolve()) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 0ef6697c6c..78e6afb65e 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder): cache_folder = create_cache_folder rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10) + ids_as_integers = [id for id in range(rec.get_num_channels())] + rec = rec.rename_channels(new_channel_ids=ids_as_integers) + + ids_as_integers = [id for id in range(sort.get_num_units())] + sort = sort.rename_units(new_unit_ids=ids_as_integers) + MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest") rec_mda = MdaRecordingExtractor(cache_folder / "mdatest") probe = rec_mda.get_probe() diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index bf0000135c..be0070d94a 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -23,6 +23,11 @@ def get_dataset(): seed=2205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + # since templates are going to be averaged and this might be a problem for amplitude scaling # we select the 3 units with the largest templates to split analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) diff --git 
a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index 724ba2c963..c18c7d37af 100644 --- a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -14,12 +14,12 @@ def test_clip(): rec1 = clip(rec, a_min=-1.5) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(-2 <= traces0[0] <= 3) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"]) assert traces1.shape[1] == 2 assert np.all(-1.5 <= traces1[1]) @@ -34,11 +34,11 @@ def test_blank_staturation(): rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000) rec1.save(verbose=False) - traces0 = rec0.get_traces(segment_index=0, channel_ids=[1]) + traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"]) assert traces0.shape[1] == 1 assert np.all(traces0 < 3.0) - traces1 = rec1.get_traces(segment_index=0, channel_ids=[0]) + traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure a_min = rec1._recording_segments[0].a_min diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 1189f04f7d..06bde4e3d1 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -163,7 +163,9 @@ def test_output_values(): expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) - si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1) + si_interpolated_recording = spre.interpolate_bad_channels( + recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1 + ) si_interpolated = si_interpolated_recording.get_traces() expected_ts = si_interpolated[:, 1:] @ expected_weights diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 576b570832..151752e0e6 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -15,7 +15,7 @@ def test_normalize_by_quantile(): rec2 = normalize_by_quantile(rec, mode="by_channel") rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 rec2 = normalize_by_quantile(rec, mode="pool_channel") diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index b8bb31015e..a2a06e7a1f 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -15,7 +15,7 @@ def test_rectify(): rec2 = rectify(rec) rec2.save(verbose=False) - traces = rec2.get_traces(segment_index=0, channel_ids=[1]) + traces = rec2.get_traces(segment_index=0, channel_ids=["1"]) assert traces.shape[1] == 1 # import matplotlib.pyplot as plt diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index ca21f1e45f..c789d1af82 100644 --- 
a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -42,6 +42,9 @@ max_spikes=10000, min_spikes=10, min_fr=0.0, n_neighbors=4, n_components=10, radius_um=100, peak_sign="neg" ), silhouette=dict(method=("simplified",)), + isolation_distance=dict(), + l_ratio=dict(), + d_prime=dict(), ) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index d71450853f..11ce3d0160 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -17,6 +17,7 @@ _misc_metric_name_to_func, _possible_pc_metric_names, qm_compute_name_to_column_names, + column_name_to_column_dtype, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -140,13 +141,20 @@ def _merge_extension_data( all_unit_ids = new_sorting_analyzer.unit_ids not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)] + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) - metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] metrics.loc[new_unit_ids, :] = self._compute_metrics( new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs ) + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. + for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + new_data = dict(metrics=metrics) return new_data @@ -229,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # add NaN for empty units if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan + # num_spikes is an int and should be 0 + if "num_spikes" in metrics.columns: + metrics.loc[empty_unit_ids, ["num_spikes"]] = 0 # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns # (in case of NaN values) metrics = metrics.convert_dtypes() + + # we do this because the convert_dtypes infers the wrong types sometimes. + # the actual types for columns can be found in column_name_to_column_dtype dictionary. 
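+        # (for example, "num_spikes" should stay an integer column rather than the
+        # float dtype that convert_dtypes may infer once NaNs have been introduced)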
+ for column in metrics.columns: + if column in column_name_to_column_dtype: + metrics[column] = metrics[column].astype(column_name_to_column_dtype[column]) + return metrics def _run(self, verbose=False, **job_kwargs): diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index fc7e92b50d..23b781eb9d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -66,7 +66,11 @@ "amplitude_cutoff": ["amplitude_cutoff"], "amplitude_median": ["amplitude_median"], "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], - "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "synchrony": [ + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + ], "firing_range": ["firing_range"], "drift": ["drift_ptp", "drift_std", "drift_mad"], "sd_ratio": ["sd_ratio"], @@ -79,3 +83,38 @@ "silhouette": ["silhouette"], "silhouette_full": ["silhouette_full"], } + +# this dict allows us to ensure the appropriate dtype of metrics rather than allow Pandas to infer them +column_name_to_column_dtype = { + "num_spikes": int, + "firing_rate": float, + "presence_ratio": float, + "snr": float, + "isi_violations_ratio": float, + "isi_violations_count": float, + "rp_violations": float, + "rp_contamination": float, + "sliding_rp_violation": float, + "amplitude_cutoff": float, + "amplitude_median": float, + "amplitude_cv_median": float, + "amplitude_cv_range": float, + "sync_spike_2": float, + "sync_spike_4": float, + "sync_spike_8": float, + "firing_range": float, + "drift_ptp": float, + "drift_std": float, + "drift_mad": float, + "sd_ratio": float, + "isolation_distance": float, + "l_ratio": float, + "d_prime": float, + "nn_hit_rate": float, + "nn_miss_rate": float, + "nn_isolation": float, + "nn_unit_id": float, + "nn_noise_overlap": float, + "silhouette": float, + "silhouette_full": float, +} diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..ac1789a375 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -16,6 +16,11 @@ def small_sorting_analyzer(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") @@ -60,6 +65,11 @@ def sorting_analyzer_simple(): seed=1205, ) + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 60f0490f51..ea8939ebb4 100644 --- 
a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple): assert "isolation_distance" in metrics.columns +def test_merging_quality_metrics(sorting_analyzer_simple): + + sorting_analyzer = sorting_analyzer_simple + + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + skip_pc_metrics=False, + seed=2205, + ) + + # sorting_analyzer_simple has ten units + new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]]) + + new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data() + + # we should copy over the metrics after merge + for column in metrics.columns: + assert column in new_metrics.columns + # should copy dtype too + assert metrics[column].dtype == new_metrics[column].dtype + + # 10 units vs 9 units + assert len(metrics.index) > len(new_metrics.index) + + def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple @@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple): seed=2205, ) - for empty_unit_id in sorting_empty.get_empty_unit_ids(): + # num_spikes are ints not nans so we confirm empty units are nans for everything except + # num_spikes which should be 0 + nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"] + for empty_unit_ids in sorting_empty.get_empty_unit_ids(): from pandas import isnull - assert np.all(isnull(metrics_empty.loc[empty_unit_id].values)) + assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values)) + if "num_spikes" in metrics_empty.columns: + assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics() diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 2a9fb34267..ec15506006 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -66,7 +66,7 @@ class Kilosort4Sorter(BaseSorter): "do_correction": True, "keep_good_only": False, "skip_kilosort_preprocessing": False, - "use_binary_file": None, + "use_binary_file": True, "delete_recording_dat": True, } @@ -116,7 +116,7 @@ class Kilosort4Sorter(BaseSorter): "keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)", "use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. " "If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. " - "Default is None. (spikeinterface parameter)", + "Default is True. (spikeinterface parameter)", "delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. 
(spikeinterface parameter)", } diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index eed693b343..a3a3523591 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -6,12 +6,16 @@ import numpy as np from spikeinterface.core import NumpySorting -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.template import Templates from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion -from spikeinterface.sortingcomponents.tools import cache_preprocessing +from spikeinterface.sortingcomponents.tools import ( + cache_preprocessing, + get_prototype_and_waveforms_from_recording, + get_shuffled_recording_slices, +) from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity from spikeinterface.core.sortinganalyzer import create_sorting_analyzer @@ -26,7 +30,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 75}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, - "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, + "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10}, "whitening": {"mode": "local", "regularize": False}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { @@ -52,7 +56,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "multi_units_only": False, - "job_kwargs": {"n_jobs": 0.8}, + "job_kwargs": {"n_jobs": 0.5}, + "seed": 42, "debug": False, } @@ -74,18 +79,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)", "motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)", "apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\ - median reference + zscore", + median reference + whitening", + "apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not", + "matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)", "cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \ memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting", "multi_units_only": "Boolean to get only multi units activity (i.e. 
one template per electrode)", "job_kwargs": "A dictionary to specify how many jobs and which parameters they should used", + "seed": "An int to control how chunks are shuffled while detecting peaks", "debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging", } sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise. In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes - being discovered.""" + being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface""" @classmethod def get_sorter_version(cls): @@ -114,8 +122,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates from spikeinterface.sortingcomponents.tools import remove_empty_templates - from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction - from spikeinterface.sortingcomponents.tools import get_prototype_spike + from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction job_kwargs = fix_job_kwargs(params["job_kwargs"]) job_kwargs.update({"progress_bar": verbose}) @@ -132,10 +139,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## First, we are filtering the data filtering_params = params["filtering"].copy() if params["apply_preprocessing"]: + if verbose: + print("Preprocessing the recording (bandpass filtering + CMR + whitening)") recording_f = bandpass_filter(recording, **filtering_params, dtype="float32") if num_channels > 1: recording_f = common_reference(recording_f) else: + if verbose: + print("Skipping preprocessing (whitening only)") recording_f = recording recording_f.annotate(is_filtered=True) @@ -158,12 +169,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # TODO add , regularize=True chen ready whitening_kwargs = params["whitening"].copy() whitening_kwargs["dtype"] = "float32" - whitening_kwargs["radius_um"] = radius_um + whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False) if num_channels == 1: whitening_kwargs["regularize"] = False + if whitening_kwargs["regularize"]: + whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"} recording_w = whiten(recording_f, **whitening_kwargs) - noise_levels = get_noise_levels(recording_w, return_scaled=False) + noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs) if recording_w.check_serializability("json"): recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None) @@ -174,9 +187,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() - detection_params.update(job_kwargs) - - detection_params["radius_um"] = detection_params.get("radius_um", 50) + selection_params = params["selection"].copy() + detection_params["radius_um"] = radius_um detection_params["exclude_sweep_ms"] = exclude_sweep_ms detection_params["noise_levels"] = noise_levels @@ -184,17 +196,47 @@ def _run_from_folder(cls, sorter_output_folder, 
params, verbose): nbefore = int(ms_before * fs / 1000.0) nafter = int(ms_after * fs / 1000.0) + skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform" + max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels + n_peaks = max(selection_params["min_n_peaks"], max_n_peaks) + + if params["debug"]: + clustering_folder = sorter_output_folder / "clustering" + clustering_folder.mkdir(parents=True, exist_ok=True) + np.save(clustering_folder / "noise_levels.npy", noise_levels) + if params["matched_filtering"]: - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000) - prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs) + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( + recording_w, + n_peaks=10000, + ms_before=ms_before, + ms_after=ms_after, + seed=params["seed"], + **detection_params, + **job_kwargs, + ) detection_params["prototype"] = prototype detection_params["ms_before"] = ms_before - peaks = detect_peaks(recording_w, "matched_filtering", **detection_params) + if params["debug"]: + np.save(clustering_folder / "waveforms.npy", waveforms) + np.save(clustering_folder / "prototype.npy", prototype) + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: - peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params) + waveforms = None + if skip_peaks: + detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices( + recording_w, seed=params["seed"], **job_kwargs + ) + peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) - if verbose: - print("We found %d peaks in total" % len(peaks)) + if not skip_peaks and verbose: + print("Found %d peaks in total" % len(peaks)) if params["multi_units_only"]: sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids) @@ -202,14 +244,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We subselect a subset of all the peaks, by making the distributions os SNRs over all ## channels as flat as possible selection_params = params["selection"] - selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels) - selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - + selection_params["n_peaks"] = n_peaks selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks(peaks, **selection_params) if verbose: - print("We kept %d peaks for clustering" % len(selected_peaks)) + print("Kept %d peaks for clustering" % len(selected_peaks)) ## We launch a clustering (using hdbscan) relying on positions and features extracted on ## the fly from the snippets @@ -219,11 +259,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["radius_um"] = radius_um clustering_params["waveforms"]["ms_before"] = ms_before clustering_params["waveforms"]["ms_after"] = ms_after - clustering_params["job_kwargs"] = job_kwargs + clustering_params["few_waveforms"] = waveforms clustering_params["noise_levels"] = noise_levels - clustering_params["ms_before"] = exclude_sweep_ms - clustering_params["ms_after"] = 
exclude_sweep_ms + clustering_params["ms_before"] = ms_before + clustering_params["ms_after"] = ms_after + clustering_params["verbose"] = verbose clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4) legacy = clustering_params.get("legacy", True) @@ -233,7 +275,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_method = "random_projections" labels, peak_labels = find_cluster_from_peaks( - recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params + recording_w, selected_peaks, method=clustering_method, method_kwargs=clustering_params, **job_kwargs ) ## We get the labels for our peaks @@ -248,12 +290,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"]))) sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids) - clustering_folder = sorter_output_folder / "clustering" - clustering_folder.mkdir(parents=True, exist_ok=True) - - if not params["debug"]: - shutil.rmtree(clustering_folder) - else: + if params["debug"]: + np.save(clustering_folder / "peak_labels", peak_labels) np.save(clustering_folder / "labels", labels) np.save(clustering_folder / "peaks", selected_peaks) @@ -284,11 +322,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): matching_method = params["matching"].pop("method") matching_params = params["matching"].copy() matching_params["templates"] = templates - matching_job_params = job_kwargs.copy() if matching_method is not None: spikes = find_spikes_from_templates( - recording_w, matching_method, method_kwargs=matching_params, **matching_job_params + recording_w, matching_method, method_kwargs=matching_params, **job_kwargs ) if params["debug"]: @@ -297,7 +334,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(fitting_folder / "spikes", spikes) if verbose: - print("We found %d spikes" % len(spikes)) + print("Found %d spikes" % len(spikes)) ## And this is it! 
We have a spyking circus sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype) @@ -337,10 +374,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting.save(folder=curation_folder) # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + sorting = final_cleaning_circus(recording_w, sorting, templates, merging_params, **job_kwargs) if verbose: - print(f"Final merging, keeping {len(sorting.unit_ids)} units") + print(f"Kept {len(sorting.unit_ids)} units after final merging") folder_to_delete = None cache_mode = params["cache_preprocessing"].get("mode", "memory") @@ -379,17 +416,18 @@ def create_sorting_analyzer_with_templates(sorting, recording, templates, remove return sa -def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): +def final_cleaning_circus(recording, sorting, templates, merging_kwargs, **job_kwargs): from spikeinterface.core.sorting_tools import apply_merges_to_sorting sa = create_sorting_analyzer_with_templates(sorting, recording, templates) - sa.compute("unit_locations", method="monopolar_triangulation") + sa.compute("unit_locations", method="monopolar_triangulation", **job_kwargs) similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {}) - sa.compute("template_similarity", **similarity_kwargs) + sa.compute("template_similarity", **similarity_kwargs, **job_kwargs) correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {}) - sa.compute("correlograms", **correlograms_kwargs) + sa.compute("correlograms", **correlograms_kwargs, **job_kwargs) + auto_merge_kwargs = merging_kwargs.pop("auto_merge", {}) merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs) sorting = apply_merges_to_sorting(sa.sorting, merges) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 99c59f493e..bc173a6ff0 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -18,7 +18,6 @@ from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection from spikeinterface.core.template import Templates @@ -41,13 +40,7 @@ class CircusClustering: """ _default_params = { - "hdbscan_kwargs": { - "min_cluster_size": 25, - "allow_single_cluster": True, - "core_dist_n_jobs": -1, - "cluster_selection_method": "eom", - # "cluster_selection_epsilon" : 5 ## To be optimized - }, + "hdbscan_kwargs": {"min_cluster_size": 10, "allow_single_cluster": True, "min_samples": 5}, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, @@ -58,21 +51,20 @@ class CircusClustering: }, "radius_um": 100, "n_svd": [5, 2], + "few_waveforms": None, "ms_before": 0.5, "ms_after": 0.5, + "noise_threshold": 4, "rank": 5, "noise_levels": None, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, 
"random projections clustering needs hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -90,12 +82,25 @@ def main_function(cls, recording, peaks, params): tmp_folder.mkdir(parents=True, exist_ok=True) # SVD for time compression - few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter)) - few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs - ) + if params["few_waveforms"] is None: + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter) + ) + few_wfs = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + wfs = few_wfs[:, :, 0] + else: + offset = int(params["waveforms"]["ms_before"] * fs / 1000) + wfs = params["few_waveforms"][:, offset - nbefore : offset + nafter] + + # Ensure all waveforms have a positive max + wfs *= np.sign(wfs[:, nbefore])[:, np.newaxis] + + # Remove outliers + valid = np.argmax(np.abs(wfs), axis=1) == nbefore + wfs = wfs[valid] - wfs = few_wfs[:, :, 0] from sklearn.decomposition import TruncatedSVD tsvd = TruncatedSVD(params["n_svd"][0]) @@ -193,7 +198,7 @@ def main_function(cls, recording, peaks, params): original_labels = peaks["channel_index"] from spikeinterface.sortingcomponents.clustering.split import split_clusters - min_size = params["hdbscan_kwargs"].get("min_cluster_size", 50) + min_size = 2 * params["hdbscan_kwargs"].get("min_cluster_size", 10) peak_labels, _ = split_clusters( original_labels, @@ -229,47 +234,64 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + if d["rank"] is not None: from spikeinterface.sortingcomponents.matching.circus import compress_templates _, _, _, templates_array = compress_templates(templates_array, d["rank"]) templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - params["noise_levels"] = get_noise_levels(recording, return_scaled=False) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) peak_labels[mask] = -1 + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + if 
verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) - cleaning_matching_params = params["job_kwargs"].copy() - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clean.py b/src/spikeinterface/sortingcomponents/clustering/clean.py index c7d57b14e4..e8bc5a1d49 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clean.py +++ b/src/spikeinterface/sortingcomponents/clustering/clean.py @@ -32,7 +32,6 @@ def clean_clusters( count = np.zeros(n, dtype="int64") for i, label in enumerate(labels_set): count[i] = np.sum(peak_labels == label) - print(count) templates = compute_template_from_sparse(peaks, peak_labels, labels_set, sparse_wfs, sparse_mask, total_channels) @@ -42,6 +41,5 @@ def clean_clusters( max_values = -np.min(templates, axis=(1, 2)) elif peak_sign == "pos": max_values = np.max(templates, axis=(1, 2)) - print(max_values) return clean_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 08a1384333..93db9a268f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -570,7 +570,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, ) else: recording = NumpyRecording(zdata, sampling_frequency=fs) - recording = SharedMemoryRecording.from_recording(recording) + recording = SharedMemoryRecording.from_recording(recording, **job_kwargs) recording = recording.set_probe(templates.probe) recording.annotate(is_filtered=True) @@ -587,6 +587,8 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, keep_searching = True + local_job_kargs = {"n_jobs": 1, "progress_bar": False} + DEBUG = False while keep_searching: @@ -604,7 +606,11 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, local_params.update({"ignore_inds": ignore_inds + [i]}) spikes, more_outputs = find_spikes_from_templates( - sub_recording, method="circus-omp-svd", method_kwargs=local_params, extra_outputs=True, **job_kwargs + sub_recording, + method="circus-omp-svd", + method_kwargs=local_params, + extra_outputs=True, + **local_job_kargs, ) local_params["precomputed"] = more_outputs valid = (spikes["sample_index"] >= 0) * (spikes["sample_index"] < duration + 2 * margin) diff --git a/src/spikeinterface/sortingcomponents/clustering/dummy.py b/src/spikeinterface/sortingcomponents/clustering/dummy.py index c1032ee6c6..b5761ad5cf 100644 --- a/src/spikeinterface/sortingcomponents/clustering/dummy.py +++ b/src/spikeinterface/sortingcomponents/clustering/dummy.py @@ -13,7 +13,7 @@ class DummyClustering: _default_params = {} @classmethod - def main_function(cls, recording, peaks, params): + 
def main_function(cls, recording, peaks, params, job_kwargs=dict()): labels = np.arange(recording.get_num_channels(), dtype="int64") peak_labels = peaks["channel_index"] return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 99881f2f34..ba0fe6f9ac 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -41,7 +41,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, params = method_class._default_params.copy() params.update(**method_kwargs) - outputs = method_class.main_function(recording, peaks, params) + outputs = method_class.main_function(recording, peaks, params, job_kwargs=job_kwargs) if extra_outputs: return outputs diff --git a/src/spikeinterface/sortingcomponents/clustering/position.py b/src/spikeinterface/sortingcomponents/clustering/position.py index ae772206bb..dc76d787f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position.py +++ b/src/spikeinterface/sortingcomponents/clustering/position.py @@ -25,18 +25,17 @@ class PositionClustering: "hdbscan_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1}, "debug": False, "tmp_folder": None, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M"}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params if d["peak_locations"] is None: from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **d["job_kwargs"]) + peak_locations = localize_peaks(recording, peaks, **d["peak_localization_kwargs"], **job_kwargs) else: peak_locations = d["peak_locations"] diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py index d23eb26239..20067a2eec 100644 --- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py +++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py @@ -42,23 +42,14 @@ class PositionAndFeaturesClustering: "ms_before": 1.5, "ms_after": 1.5, "cleaning_method": "dip", - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): from sklearn.preprocessing import QuantileTransformer assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed" - if "n_jobs" in params["job_kwargs"]: - if params["job_kwargs"]["n_jobs"] == -1: - params["job_kwargs"]["n_jobs"] = os.cpu_count() - - if "core_dist_n_jobs" in params["hdbscan_kwargs"]: - if params["hdbscan_kwargs"]["core_dist_n_jobs"] == -1: - params["hdbscan_kwargs"]["core_dist_n_jobs"] = os.cpu_count() - d = params peak_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] @@ -80,7 +71,7 @@ def main_function(cls, recording, peaks, params): } features_data = compute_features_from_peaks( - recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **params["job_kwargs"] + recording, peaks, features_list, features_params, ms_before=1, ms_after=1, **job_kwargs ) hdbscan_data = np.zeros((len(peaks), 3), dtype=np.float32) 
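+        # note: hdbscan_data (one row per peak, 3 feature columns) is populated from
+        # the computed peak features before being handed to HDBSCAN for clustering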
@@ -150,10 +141,10 @@ def main_function(cls, recording, peaks, params):
                 dtype=recording.get_dtype(),
                 sparsity_mask=None,
                 copy=True,
-                **params["job_kwargs"],
+                **job_kwargs,
             )

-            noise_levels = get_noise_levels(recording, return_scaled=False)
+            noise_levels = get_noise_levels(recording, return_scaled=False, **job_kwargs)
             labels, peak_labels = remove_duplicates(
                 wfs_arrays, noise_levels, peak_labels, num_samples, num_chans, **params["cleaning_kwargs"]
             )
@@ -181,7 +172,7 @@ def main_function(cls, recording, peaks, params):
                 nbefore,
                 nafter,
                 return_scaled=False,
-                **params["job_kwargs"],
+                **job_kwargs,
             )
             templates = Templates(
                 templates_array=templates_array,
@@ -193,7 +184,7 @@ def main_function(cls, recording, peaks, params):
             )

             labels, peak_labels = remove_duplicates_via_matching(
-                templates, peak_labels, job_kwargs=params["job_kwargs"], **params["cleaning_kwargs"]
+                templates, peak_labels, job_kwargs=job_kwargs, **params["cleaning_kwargs"]
             )

             shutil.rmtree(tmp_folder)
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
index 4dfe3c960c..c4f372fc21 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_pca.py
@@ -38,7 +38,6 @@ class PositionAndPCAClustering:
         "ms_after": 2.5,
         "n_components_by_channel": 3,
         "n_components": 5,
-        "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True},
         "hdbscan_global_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
         "hdbscan_local_kwargs": {"min_cluster_size": 20, "allow_single_cluster": True, "core_dist_n_jobs": -1},
         "waveform_mode": "shared_memory",
@@ -73,7 +72,7 @@ def _check_params(cls, recording, peaks, params):
         return params2

     @classmethod
-    def main_function(cls, recording, peaks, params):
+    def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         # res = PositionClustering(recording, peaks, params)

         assert HAVE_HDBSCAN, "position_and_pca clustering need hdbscan to be installed"
@@ -85,9 +84,7 @@ def main_function(cls, recording, peaks, params):
         if params["peak_locations"] is None:
             from spikeinterface.sortingcomponents.peak_localization import localize_peaks

-            peak_locations = localize_peaks(
-                recording, peaks, **params["peak_localization_kwargs"], **params["job_kwargs"]
-            )
+            peak_locations = localize_peaks(recording, peaks, **params["peak_localization_kwargs"], **job_kwargs)
         else:
             peak_locations = params["peak_locations"]

@@ -155,7 +152,7 @@ def main_function(cls, recording, peaks, params):
             dtype=recording.get_dtype(),
             sparsity_mask=sparsity_mask,
             copy=(params["waveform_mode"] == "shared_memory"),
-            **params["job_kwargs"],
+            **job_kwargs,
         )

         noise = get_random_data_chunks(
@@ -222,7 +219,7 @@ def main_function(cls, recording, peaks, params):
             dtype=recording.get_dtype(),
             sparsity_mask=sparsity_mask3,
             copy=(params["waveform_mode"] == "shared_memory"),
-            **params["job_kwargs"],
+            **job_kwargs,
         )

         clean_peak_labels, peak_sample_shifts = auto_clean_clustering(
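The deleted blocks that normalised ``n_jobs == -1`` to ``os.cpu_count()`` inside each method become redundant once ``job_kwargs`` is resolved centrally by the caller (SpikeInterface's ``fix_job_kwargs``). A hedged sketch of the idea, with an illustrative helper name:

.. code::

    import os

    def resolve_n_jobs(job_kwargs: dict) -> dict:
        # illustrative stand-in for the central resolution done once by
        # spikeinterface.core.job_tools.fix_job_kwargs, so no clustering
        # method needs its own -1 -> cpu_count() special case
        resolved = dict(job_kwargs)
        if resolved.get("n_jobs", 1) == -1:
            resolved["n_jobs"] = os.cpu_count()
        return resolved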
"peak_localization_kwargs": {"method": "center_of_mass"}, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, "hdbscan_kwargs": { "min_cluster_size": 20, "min_samples": 20, @@ -38,7 +37,7 @@ class PositionPTPScaledClustering: } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "position clustering need hdbscan to be installed" d = params @@ -60,7 +59,7 @@ def main_function(cls, recording, peaks, params): if d["ptps"] is None: (ptps,) = compute_features_from_peaks( - recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **d["job_kwargs"] + recording, peaks, ["ptp"], feature_params={"ptp": {"all_channels": True}}, **job_kwargs ) else: ptps = d["ptps"] diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index f7ca999d53..1d4d8881ad 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -16,8 +16,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from .clustering_tools import remove_duplicates_via_matching -from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances -from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature from spikeinterface.core.template import Templates @@ -54,17 +53,15 @@ class RandomProjectionClustering: "random_seed": 42, "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, + "noise_threshold": 4, "tmp_folder": None, - "job_kwargs": {}, "verbose": True, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "random projections clustering need hdbscan to be installed" - job_kwargs = fix_job_kwargs(params["job_kwargs"]) - d = params verbose = d["verbose"] @@ -133,44 +130,59 @@ def main_function(cls, recording, peaks, params): nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0) nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0) + if params["noise_levels"] is None: + params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs) + templates_array = estimate_templates( - recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name=None, + **job_kwargs, ) + best_channels = np.argmax(np.abs(templates_array[:, nbefore, :]), axis=1) + peak_snrs = np.abs(templates_array[:, nbefore, :]) + best_snrs_ratio = (peak_snrs / params["noise_levels"])[np.arange(len(peak_snrs)), best_channels] + valid_templates = best_snrs_ratio > params["noise_threshold"] + templates = Templates( - templates_array=templates_array, + templates_array=templates_array[valid_templates], sampling_frequency=fs, nbefore=nbefore, sparsity_mask=None, channel_ids=recording.channel_ids, - unit_ids=unit_ids, + unit_ids=unit_ids[valid_templates], probe=recording.get_probe(), is_scaled=False, ) - if params["noise_levels"] is None: - 
params["noise_levels"] = get_noise_levels(recording, return_scaled=False) - sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) + + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) + empty_templates = templates.sparsity_mask.sum(axis=1) == 0 templates = remove_empty_templates(templates) - if verbose: - print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids))) + mask = np.isin(peak_labels, np.where(empty_templates)[0]) + peak_labels[mask] = -1 - cleaning_matching_params = job_kwargs.copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in cleaning_matching_params: - cleaning_matching_params[value] = None - cleaning_matching_params["chunk_duration"] = "100ms" - cleaning_matching_params["n_jobs"] = 1 - cleaning_matching_params["progress_bar"] = False + mask = np.isin(peak_labels, np.where(~valid_templates)[0]) + peak_labels[mask] = -1 + + if verbose: + print("Found %d raw clusters, starting to clean with matching" % (len(templates.unit_ids))) + cleaning_job_kwargs = job_kwargs.copy() + cleaning_job_kwargs["progress_bar"] = False cleaning_params = params["cleaning_kwargs"].copy() labels, peak_labels = remove_duplicates_via_matching( - templates, peak_labels, job_kwargs=cleaning_matching_params, **cleaning_params + templates, peak_labels, job_kwargs=cleaning_job_kwargs, **cleaning_params ) if verbose: - print("We kept %d non-duplicated clusters..." % len(labels)) + print("Kept %d non-duplicated clusters" % len(labels)) return labels, peak_labels diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 8b9acbc92d..5f8ac99848 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -23,7 +23,7 @@ get_random_data_chunks, extract_waveforms_to_buffers, ) -from .clustering_tools import auto_clean_clustering, auto_split_clustering +from .clustering_tools import auto_clean_clustering class SlidingHdbscanClustering: @@ -55,18 +55,17 @@ class SlidingHdbscanClustering: "auto_merge_quantile_limit": 0.8, "ratio_num_channel_intersect": 0.5, # ~ 'auto_trash_misalignment_shift' : 4, - "job_kwargs": {"n_jobs": -1, "chunk_memory": "10M", "progress_bar": True}, } @classmethod - def main_function(cls, recording, peaks, params): + def main_function(cls, recording, peaks, params, job_kwargs=dict()): assert HAVE_HDBSCAN, "sliding_hdbscan clustering need hdbscan to be installed" params = cls._check_params(recording, peaks, params) wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params) peak_labels = cls._find_clusters(recording, peaks, wfs_arrays, sparsity_mask, noise, params) wfs_arrays2, sparsity_mask2 = cls._prepare_clean( - recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params + recording, peaks, wfs_arrays, sparsity_mask, peak_labels, params, job_kwargs ) clean_peak_labels, peak_sample_shifts = cls._clean_cluster( @@ -100,7 +99,7 @@ def _check_params(cls, recording, peaks, params): return params2 @classmethod - def _initialize_folder(cls, recording, peaks, params): + def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()): d = params tmp_folder = params["tmp_folder"] @@ -145,7 +144,7 @@ def _initialize_folder(cls, recording, peaks, params): dtype=dtype, 
diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
index a6ffa5fdc2..40cedacdc5 100644
--- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
+++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py
@@ -71,11 +71,10 @@ class SlidingNNClustering:
         "tmp_folder": None,
         "verbose": False,
         "tmp_folder": None,
-        "job_kwargs": {"n_jobs": -1},
     }

     @classmethod
-    def _initialize_folder(cls, recording, peaks, params):
+    def _initialize_folder(cls, recording, peaks, params, job_kwargs=dict()):
         assert HAVE_NUMBA, "SlidingNN needs numba to work"
         assert HAVE_TORCH, "SlidingNN needs torch to work"
         assert HAVE_NNDESCENT, "SlidingNN needs pynndescent to work"
@@ -126,16 +125,16 @@ def _initialize_folder(cls, recording, peaks, params):
             dtype=dtype,
             sparsity_mask=sparsity_mask,
             copy=(d["waveform_mode"] == "shared_memory"),
-            **d["job_kwargs"],
+            **job_kwargs,
         )

         return wfs_arrays, sparsity_mask

     @classmethod
-    def main_function(cls, recording, peaks, params):
+    def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         d = params

-        # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params)
+        # wfs_arrays, sparsity_mask, noise = cls._initialize_folder(recording, peaks, params, job_kwargs)

         # prepare neighborhood parameters
         fs = recording.get_sampling_frequency()
@@ -228,7 +227,7 @@ def main_function(cls, recording, peaks, params):
             n_channel_neighbors=d["n_channel_neighbors"],
             low_memory=d["low_memory"],
             knn_verbose=d["verbose"],
-            n_jobs=d["job_kwargs"]["n_jobs"],
+            n_jobs=job_kwargs["n_jobs"],
         )
         # remove the first nearest neighbor (which should be self)
         knn_distances = knn_distances[:, 1:]
@@ -297,7 +296,7 @@ def main_function(cls, recording, peaks, params):
         # TODO HDBSCAN can be done on GPU with NVIDIA RAPIDS for speed
         clusterer = hdbscan.HDBSCAN(
             prediction_data=True,
-            core_dist_n_jobs=d["job_kwargs"]["n_jobs"],
+            core_dist_n_jobs=job_kwargs["n_jobs"],
             **d["hdbscan_kwargs"],
         ).fit(embeddings_chunk)

diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py
index 13af5b0fab..59472d1374 100644
--- a/src/spikeinterface/sortingcomponents/clustering/tdc.py
+++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py
@@ -9,27 +9,21 @@
 from spikeinterface.core import (
     get_channel_distances,
-    Templates,
-    compute_sparsity,
     get_global_tmp_folder,
 )
 from spikeinterface.core.node_pipeline import (
     run_node_pipeline,
-    ExtractDenseWaveforms,
     ExtractSparseWaveforms,
     PeakRetriever,
 )
-from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing
-from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive
+from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
 from spikeinterface.sortingcomponents.peak_selection import select_peaks
-from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution
 from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection
 from spikeinterface.sortingcomponents.clustering.split import split_clusters
 from spikeinterface.sortingcomponents.clustering.merge import merge_clusters

-from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse


 class TdcClustering:
@@ -50,15 +44,12 @@ class TdcClustering:
             "merge_radius_um": 40.0,
             "threshold_diff": 1.5,
         },
-        "job_kwargs": {},
     }

     @classmethod
-    def main_function(cls, recording, peaks, params):
+    def main_function(cls, recording, peaks, params, job_kwargs=dict()):
         import hdbscan

-        job_kwargs = params["job_kwargs"]
-
         if params["folder"] is None:
             randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
             clustering_folder = get_global_tmp_folder() / f"tdcclustering_{randname}"
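Note that ``sliding_nn`` reads ``n_jobs`` straight out of ``job_kwargs``, which presumes the caller has filled it in; a defensive variant would fall back to a default. This is an editorial sketch of that alternative, not what the patch does:

.. code::

    # fall back to a single job when the caller did not specify n_jobs
    n_jobs = job_kwargs.get("n_jobs", 1)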
diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py
index e2a0d273d6..64cc0f39c4 100644
--- a/src/spikeinterface/sortingcomponents/clustering/tools.py
+++ b/src/spikeinterface/sortingcomponents/clustering/tools.py
@@ -172,8 +172,6 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False):

     """

-    print("apply_waveforms_shift")
-
     if inplace:
         aligned_waveforms = waveforms
     else:
@@ -193,6 +191,4 @@ def apply_waveforms_shift(waveforms, peak_shifts, inplace=False):
         else:
             aligned_waveforms[mask, -shift:, :] = wfs[:, :-shift, :]

-    print("apply_waveforms_shift DONE")
-
     return aligned_waveforms
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index d03744f8f9..12955e2c40 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -118,7 +118,11 @@ def detect_peaks(
             squeeze_output = True
         else:
             squeeze_output = False
-            job_name += f" + {len(pipeline_nodes)} nodes"
+            if len(pipeline_nodes) == 1:
+                plural = ""
+            else:
+                plural = "s"
+            job_name += f" + {len(pipeline_nodes)} node{plural}"

         # because node are modified inplace (insert parent) they need to copy incase
         # the same pipeline is run several times
@@ -677,7 +681,6 @@ def __init__(
         medians = medians[:, None]
         noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817
         self.abs_thresholds = noise_levels * detect_threshold
-
         self._dtype = np.dtype(base_peak_dtype + [("z", "float32")])

     def get_dtype(self):
@@ -727,8 +730,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
             return (np.zeros(0, dtype=self._dtype),)

         peak_sample_ind += self.exclude_sweep_size + self.conv_margin + self.nbefore
-
         peak_amplitude = traces[peak_sample_ind, peak_chan_ind]
+
         local_peaks = np.zeros(peak_sample_ind.size, dtype=self._dtype)
         local_peaks["sample_index"] = peak_sample_ind
         local_peaks["channel_index"] = peak_chan_ind
diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index 08bcabf5e5..1e4e0edded 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -33,7 +33,7 @@
     get_grid_convolution_templates_and_weights,
 )

-from .tools import get_prototype_spike
+from .tools import get_prototype_and_waveforms_from_peaks


 def get_localization_pipeline_nodes(
@@ -73,8 +73,8 @@ def get_localization_pipeline_nodes(
         assert isinstance(peak_source, (PeakRetriever, SpikeRetriever))
         # extract prototypes silently
         job_kwargs["progress_bar"] = False
-        method_kwargs["prototype"] = get_prototype_spike(
-            recording, peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
+        method_kwargs["prototype"], _, _ = get_prototype_and_waveforms_from_peaks(
+            recording, peaks=peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
         )
         extract_dense_waveforms = ExtractDenseWaveforms(
             recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False
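Call sites migrate from ``get_prototype_spike``, which returned only the prototype, to ``get_prototype_and_waveforms_from_peaks``, which returns a ``(prototype, waveforms, peaks)`` triple; unused values are discarded. Sketch of the migration at a call site:

.. code::

    # before (old helper, single return value):
    # prototype = get_prototype_spike(recording, peaks, ms_before, ms_after, **job_kwargs)

    # after (new helper, keyword peaks argument, triple return):
    prototype, _, _ = get_prototype_and_waveforms_from_peaks(
        recording, peaks=peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
    )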
diff --git a/src/spikeinterface/sortingcomponents/tests/common.py b/src/spikeinterface/sortingcomponents/tests/common.py
index 01e4445a13..d5e5b6be1b 100644
--- a/src/spikeinterface/sortingcomponents/tests/common.py
+++ b/src/spikeinterface/sortingcomponents/tests/common.py
@@ -21,4 +21,10 @@ def make_dataset():
         noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
         seed=2205,
     )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
     return recording, sorting
diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
index 7c34f5948d..341ed3426d 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -22,7 +22,7 @@
 )

 from spikeinterface.core.node_pipeline import run_node_pipeline
-from spikeinterface.sortingcomponents.tools import get_prototype_spike
+from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_peaks

 from spikeinterface.sortingcomponents.tests.common import make_dataset

@@ -314,7 +314,9 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)
     ms_before = 1.0
     ms_after = 1.0

-    prototype = get_prototype_spike(recording, peaks_by_channel_np, ms_before, ms_after, **job_kwargs)
+    prototype, _, _ = get_prototype_and_waveforms_from_peaks(
+        recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs
+    )

     peaks_local_mf_filtering = detect_peaks(
         recording,
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 1501582336..1bd2381cda 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -69,25 +69,174 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j
     return all_wfs


-def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs):
+def get_prototype_and_waveforms_from_peaks(
+    recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs
+):
+    """
+    Function to extract a prototype waveform from peaks.
+
+    Parameters
+    ----------
+    recording : Recording
+        The recording object containing the data.
+    peaks : numpy.array
+        Array of peaks from which to extract the prototype.
+    n_peaks : int, optional
+        Number of peaks to consider, by default 5000.
+    ms_before : float, optional
+        Time in milliseconds before the peak to extract the waveform, by default 0.5.
+    ms_after : float, optional
+        Time in milliseconds after the peak to extract the waveform, by default 0.5.
+    seed : int or None, optional
+        Seed for random number generator, by default None.
+    **all_kwargs : dict
+        Additional keyword arguments for peak detection and job kwargs.
+
+    Returns
+    -------
+    prototype : numpy.array
+        The prototype waveform.
+    waveforms : numpy.array
+        The extracted waveforms for the selected peaks.
+    peaks : numpy.array
+        The selected peaks used to extract waveforms.
+    """
     from spikeinterface.sortingcomponents.peak_selection import select_peaks

+    _, job_kwargs = split_job_kwargs(all_kwargs)
+
     nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
     nafter = int(ms_after * recording.sampling_frequency / 1000.0)

-    few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter))
-
+    few_peaks = select_peaks(
+        peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed
+    )
     waveforms = extract_waveform_at_max_channel(
         recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
     )
+
+    with np.errstate(divide="ignore", invalid="ignore"):
+        prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
+
+    return prototype, waveforms[:, :, 0], few_peaks
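The prototype is the per-sample median over max-channel waveforms, each normalised by its absolute amplitude at the alignment sample ``nbefore``, so large and small units contribute equally. A stand-alone sketch of that computation (2-D input assumed for clarity; the patched code carries a trailing channel axis):

.. code::

    import numpy as np

    def compute_prototype(waveforms, nbefore):
        # waveforms: (num_spikes, num_samples), max channel only
        peak_amp = np.abs(waveforms[:, nbefore])[:, np.newaxis]
        with np.errstate(divide="ignore", invalid="ignore"):
            # zero-amplitude spikes become NaN rows; nanmedian skips them
            return np.nanmedian(waveforms / peak_amp, axis=0)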


+def get_prototype_and_waveforms_from_recording(
+    recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs
+):
+    """
+    Function to extract a prototype waveform from peaks detected on the fly.
+
+    Parameters
+    ----------
+    recording : Recording
+        The recording object containing the data.
+    n_peaks : int, optional
+        Number of peaks to consider, by default 5000.
+    ms_before : float, optional
+        Time in milliseconds before the peak to extract the waveform, by default 0.5.
+    ms_after : float, optional
+        Time in milliseconds after the peak to extract the waveform, by default 0.5.
+    seed : int or None, optional
+        Seed for random number generator, by default None.
+    **all_kwargs : dict
+        Additional keyword arguments for peak detection and job kwargs.
+
+    Returns
+    -------
+    prototype : numpy.array
+        The prototype waveform.
+    waveforms : numpy.array
+        The extracted waveforms for the selected peaks.
+    peaks : numpy.array
+        The selected peaks used to extract waveforms.
+    """
+    from spikeinterface.sortingcomponents.peak_detection import detect_peaks
+    from spikeinterface.core.node_pipeline import ExtractSparseWaveforms
+
+    detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs)
+
+    nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
+    node = ExtractSparseWaveforms(
+        recording,
+        parents=None,
+        return_output=True,
+        ms_before=ms_before,
+        ms_after=ms_after,
+        radius_um=0,
+    )
+
+    pipeline_nodes = [node]
+
+    recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs)
+
+    res = detect_peaks(
+        recording,
+        pipeline_nodes=pipeline_nodes,
+        skip_after_n_peaks=n_peaks,
+        recording_slices=recording_slices,
+        **detection_kwargs,
+        **job_kwargs,
+    )
+
+    rng = np.random.RandomState(seed)
+    indices = rng.permutation(np.arange(len(res[0])))
+
+    few_peaks = res[0][indices[:n_peaks]]
+    waveforms = res[1][indices[:n_peaks]]
+
     with np.errstate(divide="ignore", invalid="ignore"):
         prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
-    return prototype
+
+    return prototype, waveforms[:, :, 0], few_peaks
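When no peaks are supplied, detection runs over randomly ordered chunks and stops once roughly ``n_peaks`` have been found, so the prototype samples the whole file at a fraction of the cost of a full pass. A hedged usage sketch (the ``skip_after_n_peaks`` and ``recording_slices`` keywords are the ones this patch introduces; method and chunking choices are illustrative):

.. code::

    # visit chunks in random order, then stop detection early
    slices = get_shuffled_recording_slices(recording, seed=42, chunk_duration="1s")
    peaks = detect_peaks(
        recording,
        method="locally_exclusive",
        skip_after_n_peaks=5000,      # stop once enough peaks are collected
        recording_slices=slices,      # random chunk order avoids sampling only the file start
        chunk_duration="1s",
    )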
+ """ + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.core.node_pipeline import ExtractSparseWaveforms + + detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + node = ExtractSparseWaveforms( + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) + + pipeline_nodes = [node] + + recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) + + res = detect_peaks( + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, + ) + + rng = np.random.RandomState(seed) + indices = rng.permutation(np.arange(len(res[0]))) + + few_peaks = res[0][indices[:n_peaks]] + waveforms = res[1][indices[:n_peaks]] + with np.errstate(divide="ignore", invalid="ignore"): prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - return prototype + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform either from peaks or from a peak detection. Note that in case + of a peak detection, the detection stops as soon as n_peaks are detected. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. 
+ """ + if peaks is None: + return get_prototype_and_waveforms_from_recording( + recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) + else: + return get_prototype_and_waveforms_from_peaks( + recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs + ) def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels() - if num_channels < 32: + if num_channels <= 32: return False else: locations = recording.get_channel_locations() @@ -151,3 +300,20 @@ def fit_sigmoid(xdata, ydata, p0=None): popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) return popt + + +def get_shuffled_recording_slices(recording, seed=None, **job_kwargs): + from spikeinterface.core.job_tools import ensure_chunk_size + from spikeinterface.core.job_tools import divide_segment_into_chunks + + chunk_size = ensure_chunk_size(recording, **job_kwargs) + recording_slices = [] + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + chunks = divide_segment_into_chunks(num_frames, chunk_size) + recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + + rng = np.random.default_rng(seed) + recording_slices = rng.permutation(recording_slices) + + return recording_slices diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index c593836061..3b31eacee5 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -565,7 +565,7 @@ def _update_plot(self, change): channel_locations = self.sorting_analyzer.get_channel_locations() else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] - templates = self.templates.templates_array[unit_indices] + templates = self.templates.get_dense_templates()[unit_indices] templates_shadings = None channel_locations = self.templates.get_channel_locations()