Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 30, 2024
1 parent 66f9098 commit c411d7c
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/spikeinterface/curation/auto_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from spikeinterface.core import SortingAnalyzer
from spikeinterface.qualitymetrics.quality_metric_calculator import get_default_qm_params
from spikeinterface.postprocessing.template_metrics import _default_function_kwargs as default_template_metrics_params


class ModelBasedClassification:
"""
Class for performing model-based classification on spike sorting data.
Expand Down Expand Up @@ -78,7 +80,7 @@ def predict_labels(self):
return classified_units

def _get_metrics_for_classification(self):
""" Check if all required metrics are present and return a DataFrame of metrics for classification """
"""Check if all required metrics are present and return a DataFrame of metrics for classification"""

try:
quality_metrics = self.sorting_analyzer.extensions["quality_metrics"].data["metrics"]
Expand All @@ -100,15 +102,15 @@ def _get_metrics_for_classification(self):
return calculated_metrics

def _check_params_for_classification(self):
""" Check that quality and template metrics parameters match those used to train the model
"""Check that quality and template metrics parameters match those used to train the model
NEEDS UPDATING TO PULL IN PARAMS FROM TRAINING DATA"""

try:
quality_metrics_params = self.sorting_analyzer.extensions["quality_metrics"].params["qm_params"]
template_metric_params = self.sorting_analyzer.extensions["template_metrics"].params["qm_params"]
except KeyError:
raise ValueError("Quality and template metrics must be computed before classification")

# TODO: check metrics_params match those used to train the model - how?
# TEMP - check that params match the default. Need to add ability to check against model training params
default_quality_metrics_params = get_default_qm_params()
Expand All @@ -125,6 +127,7 @@ def _check_params_for_classification(self):
# TODO: decide whether to also check params against parent extensions of metrics (e.g. waveforms, templates)
# This would need to account for the fact that these extensions may no longer exist


def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline: Pipeline, required_metrics: Sequence[str]):
"""
Automatically labels units based on a model-based classification.
Expand Down

0 comments on commit c411d7c

Please sign in to comment.