Skip to content

Commit

Permalink
Parameter checking for classification
Browse files Browse the repository at this point in the history
  • Loading branch information
jakeswann1 committed May 30, 2024
1 parent 88bb8b4 commit 66f9098
Showing 1 changed file with 70 additions and 3 deletions.
73 changes: 70 additions & 3 deletions src/spikeinterface/curation/auto_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,63 @@
from sklearn.pipeline import Pipeline

from spikeinterface.core import SortingAnalyzer
from spikeinterface.qualitymetrics.quality_metric_calculator import get_default_qm_params
from spikeinterface.postprocessing.template_metrics import _default_function_kwargs as default_template_metrics_params
class ModelBasedClassification:
"""
Class for performing model-based classification on spike sorting data.
Parameters
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data.
pipeline : Pipeline
The pipeline object representing the trained classification model.
required_metrics : Sequence[str]
The list of required metrics for classification.
class ModelBasedClassification:
# TODO docstring
Attributes
----------
sorting_analyzer : SortingAnalyzer
The sorting analyzer object containing the spike sorting data.
required_metrics : Sequence[str]
The list of required metrics for classification.
pipeline : Pipeline
The pipeline object representing the trained classification model.
Methods
-------
predict_labels()
Predicts the labels for the spike sorting data using the trained model.
_get_metrics_for_classification()
Retrieves the metrics data required for classification.
_check_params_for_classification()
Checks if the parameters for classification match the training parameters.
"""

def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline: Pipeline, required_metrics: Sequence[str]):
    """
    Keep references to the inputs needed for model-based classification.

    Parameters
    ----------
    sorting_analyzer : SortingAnalyzer
        The sorting analyzer object containing the spike sorting data.
    pipeline : Pipeline
        The pipeline object representing the trained classification model.
    required_metrics : Sequence[str]
        The list of required metrics for classification.
    """
    # Plain attribute storage; nothing is validated or computed at construction time.
    self.pipeline = pipeline
    self.required_metrics = required_metrics
    self.sorting_analyzer = sorting_analyzer

def predict_labels(self):
"""
Predicts the labels for the spike sorting data using the trained model.
Returns
-------
dict
A dictionary containing the classified units and their corresponding predictions and probabilities.
The dictionary has the format {unit_id: (prediction, probability)}.
"""

# Get metrics DataFrame for classification
input_data = self._get_metrics_for_classification()

# Check params match training data
self._check_params_for_classification()

# Prepare input data
input_data[np.isinf(input_data)] = np.nan
input_data = input_data.astype("float32")
Expand All @@ -37,6 +78,7 @@ def predict_labels(self):
return classified_units

def _get_metrics_for_classification(self):
""" Check if all required metrics are present and return a DataFrame of metrics for classification """

try:
quality_metrics = self.sorting_analyzer.extensions["quality_metrics"].data["metrics"]
Expand All @@ -57,6 +99,31 @@ def _get_metrics_for_classification(self):

return calculated_metrics

def _check_params_for_classification(self):
    """
    Check that quality and template metric parameters match those used to train the model.

    The parameters used at training time are not yet available here, so as a temporary
    measure both parameter sets are compared against the library defaults.
    NEEDS UPDATING TO PULL IN PARAMS FROM TRAINING DATA.

    Raises
    ------
    ValueError
        If the "quality_metrics" or "template_metrics" extension (or its parameter
        entry) is missing from the sorting analyzer, or if either parameter set
        differs from the defaults.
    """
    extensions = self.sorting_analyzer.extensions

    try:
        quality_metrics_params = extensions["quality_metrics"].params["qm_params"]
        # The template metrics extension keeps its per-metric kwargs under
        # "metrics_kwargs"; looking up "qm_params" here (as before) raised a KeyError
        # that was then misreported as "metrics must be computed".
        # NOTE(review): confirm the key against ComputeTemplateMetrics if its params change.
        template_metrics_params = extensions["template_metrics"].params["metrics_kwargs"]
    except KeyError as err:
        # Chain the original error so the missing extension/key name stays visible.
        raise ValueError("Quality and template metrics must be computed before classification") from err

    # TODO: check metrics_params match those used to train the model - how?
    # TEMP - check that params match the default. Need to add ability to check against model training params
    default_quality_metrics_params = get_default_qm_params()

    # Check that dicts are identical
    if quality_metrics_params != default_quality_metrics_params:
        raise ValueError("Quality metrics params do not match default params")
    if template_metrics_params != default_template_metrics_params:
        raise ValueError("Template metrics params do not match default params")

    # TODO: decide whether to also check params against parent extensions of metrics (e.g. waveforms, templates)
    # This would need to account for the fact that these extensions may no longer exist

def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline: Pipeline, required_metrics: Sequence[str]):
"""
Expand Down

0 comments on commit 66f9098

Please sign in to comment.