From c411d7c8c8d5fd97565808dacd0e8d18b07922e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 May 2024 15:48:14 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_label.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_label.py b/src/spikeinterface/curation/auto_label.py index 0e4979f91b..f6f6e55f28 100644 --- a/src/spikeinterface/curation/auto_label.py +++ b/src/spikeinterface/curation/auto_label.py @@ -6,6 +6,8 @@ from spikeinterface.core import SortingAnalyzer from spikeinterface.qualitymetrics.quality_metric_calculator import get_default_qm_params from spikeinterface.postprocessing.template_metrics import _default_function_kwargs as default_template_metrics_params + + class ModelBasedClassification: """ Class for performing model-based classification on spike sorting data. @@ -78,7 +80,7 @@ def predict_labels(self): return classified_units def _get_metrics_for_classification(self): - """ Check if all required metrics are present and return a DataFrame of metrics for classification """ + """Check if all required metrics are present and return a DataFrame of metrics for classification""" try: quality_metrics = self.sorting_analyzer.extensions["quality_metrics"].data["metrics"] @@ -100,7 +102,7 @@ def _get_metrics_for_classification(self): return calculated_metrics def _check_params_for_classification(self): - """ Check that quality and template metrics parameters match those used to train the model + """Check that quality and template metrics parameters match those used to train the model NEEDS UPDATING TO PULL IN PARAMS FROM TRAINING DATA""" try: @@ -108,7 +110,7 @@ def _check_params_for_classification(self): template_metric_params = self.sorting_analyzer.extensions["template_metrics"].params["qm_params"] except KeyError: raise ValueError("Quality and template metrics must be computed before classification") - + # TODO: check metrics_params match those used to train the model - how? # TEMP - check that params match the default. Need to add ability to check against model training params default_quality_metrics_params = get_default_qm_params() @@ -125,6 +127,7 @@ def _check_params_for_classification(self): # TODO: decide whether to also check params against parent extensions of metrics (e.g. waveforms, templates) # This would need to account for the fact that these extensions may no longer exist + def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline: Pipeline, required_metrics: Sequence[str]): """ Automatically labels units based on a model-based classification.