Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch calculate_pc_metrics to compute_pc_metrics for api consistency #2925

Merged
merged 5 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,14 @@ def get_quality_pca_metric_list():
return deepcopy(_possible_pc_metric_names)


def calculate_pc_metrics(
sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
def compute_pc_metrics(
sorting_analyzer,
metric_names=None,
qm_params=None,
unit_ids=None,
seed=None,
n_jobs=1,
progress_bar=False,
):
"""Calculate principal component derived metrics.

Expand Down Expand Up @@ -180,6 +186,28 @@ def calculate_pc_metrics(
return pc_metrics


def calculate_pc_metrics(
sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
):
warnings.warn(
"The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. Please use compute_pc_metrics instead",
category=DeprecationWarning,
stacklevel=2,
)

pc_metrics = compute_pc_metrics(
sorting_analyzer,
metric_names=metric_names,
qm_params=qm_params,
unit_ids=unit_ids,
seed=seed,
n_jobs=n_jobs,
progress_bar=progress_bar,
)

return pc_metrics


#################################################################
# Code from spikemetrics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension


from .quality_metric_list import calculate_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params

Expand Down Expand Up @@ -143,7 +143,7 @@ def _run(self, verbose=False, **job_kwargs):
if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]:
if not self.sorting_analyzer.has_extension("principal_components"):
raise ValueError("waveform_principal_component must be provied")
pc_metrics = calculate_pc_metrics(
pc_metrics = compute_pc_metrics(
self.sorting_analyzer,
unit_ids=non_empty_unit_ids,
metric_names=pc_metric_names,
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
)

from .pca_metrics import (
calculate_pc_metrics,
compute_pc_metrics,
calculate_pc_metrics, # remove after 0.103.0
mahalanobis_metrics,
lda_metrics,
nearest_neighbors_metrics,
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions

from spikeinterface.qualitymetrics import (
calculate_pc_metrics,
compute_pc_metrics,
nearest_neighbors_isolation,
nearest_neighbors_noise_overlap,
)
Expand Down Expand Up @@ -55,10 +55,10 @@ def sorting_analyzer_simple():

def test_calculate_pc_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
res1 = calculate_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
res1 = pd.DataFrame(res1)

res2 = calculate_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
res2 = pd.DataFrame(res2)

for k in res1.columns:
Expand Down