Skip to content

Commit

Permalink
Merge branch 'main' into support_python_3.13
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Jan 20, 2025
2 parents 97ee37f + cabe66e commit f0cd3cb
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 108 deletions.
70 changes: 57 additions & 13 deletions src/spikeinterface/widgets/sorting_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

import warnings

from .base import BaseWidget, to_attr

from .amplitudes import AmplitudesWidget
Expand All @@ -14,6 +16,9 @@
from ..core import SortingAnalyzer


_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"]


class SortingSummaryWidget(BaseWidget):
"""
Plots spike sorting summary.
Expand Down Expand Up @@ -42,14 +47,24 @@ class SortingSummaryWidget(BaseWidget):
label_choices : list or None, default: None
List of labels to be added to the curation table
(sortingview backend)
unit_table_properties : list or None, default: None
displayed_unit_properties : list or None, default: None
List of properties to be added to the unit table.
These may be drawn from the sorting extractor, and, if available,
the quality_metrics and template_metrics extensions of the SortingAnalyzer.
the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer.
See all properties available with sorting.get_property_keys(), and, if available,
analyzer.get_extension("quality_metrics").get_data().columns and
analyzer.get_extension("template_metrics").get_data().columns.
(sortingview backend)
extra_unit_properties : dict or None, default: None
A dict with extra units properties to display.
curation_dict : dict or None, default: None
When curation is True, optionaly the viewer can get a previous 'curation_dict'
to continue/check previous curations on this analyzer.
In this case label_definitions must be None beacuse it is already included in the curation_dict.
(spikeinterface_gui backend)
label_definitions : dict or None, default: None
When curation is True, optionaly the user can provide a label_definitions dict.
This replaces the label_choices in the curation_format.
(spikeinterface_gui backend)
"""

def __init__(
Expand All @@ -60,11 +75,24 @@ def __init__(
max_amplitudes_per_unit=None,
min_similarity_for_correlograms=0.2,
curation=False,
unit_table_properties=None,
displayed_unit_properties=None,
extra_unit_properties=None,
label_choices=None,
curation_dict=None,
label_definitions=None,
backend=None,
unit_table_properties=None,
**backend_kwargs,
):

if unit_table_properties is not None:
warnings.warn(
"plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead",
category=DeprecationWarning,
stacklevel=2,
)
displayed_unit_properties = unit_table_properties

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
self.check_extensions(
sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"]
Expand All @@ -74,18 +102,29 @@ def __init__(
if unit_ids is None:
unit_ids = sorting.get_unit_ids()

plot_data = dict(
if curation_dict is not None and label_definitions is not None:
raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both")

if displayed_unit_properties is None:
displayed_unit_properties = list(_default_displayed_unit_properties)
if extra_unit_properties is not None:
displayed_unit_properties += list(extra_unit_properties.keys())

data_plot = dict(
sorting_analyzer=sorting_analyzer,
unit_ids=unit_ids,
sparsity=sparsity,
min_similarity_for_correlograms=min_similarity_for_correlograms,
unit_table_properties=unit_table_properties,
displayed_unit_properties=displayed_unit_properties,
extra_unit_properties=extra_unit_properties,
curation=curation,
label_choices=label_choices,
max_amplitudes_per_unit=max_amplitudes_per_unit,
curation_dict=curation_dict,
label_definitions=label_definitions,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
Expand Down Expand Up @@ -156,7 +195,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

# unit ids
v_units_table = generate_unit_table_view(
dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores
dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores
)

if dp.curation:
Expand Down Expand Up @@ -190,9 +229,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
def plot_spikeinterface_gui(self, data_plot, **backend_kwargs):
sorting_analyzer = data_plot["sorting_analyzer"]

import spikeinterface_gui
from spikeinterface_gui import run_mainwindow

app = spikeinterface_gui.mkQApp()
win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"])
win.show()
app.exec_()
run_mainwindow(
sorting_analyzer,
with_traces=True,
curation=data_plot["curation"],
curation_dict=data_plot["curation_dict"],
label_definitions=data_plot["label_definitions"],
extra_unit_properties=data_plot["extra_unit_properties"],
displayed_unit_properties=data_plot["displayed_unit_properties"],
)
31 changes: 22 additions & 9 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def setUpClass(cls):
spike_amplitudes=dict(),
unit_locations=dict(),
spike_locations=dict(),
quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]),
quality_metrics=dict(
metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"]
),
template_metrics=dict(),
correlograms=dict(),
template_similarity=dict(),
Expand Down Expand Up @@ -531,26 +533,37 @@ def test_plot_sorting_summary(self):
possible_backends = list(sw.SortingSummaryWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend])
sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend])
sw.plot_sorting_summary(
self.sorting_analyzer_dense,
displayed_unit_properties=[],
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
displayed_unit_properties=[],
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
sparsity=self.sparsity_strict,
displayed_unit_properties=[],
backend=backend,
**self.backend_kwargs[backend],
)
# add unit_properties
# select unit_properties
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
unit_table_properties=["firing_rate", "snr"],
displayed_unit_properties=["firing_rate", "snr"],
backend=backend,
**self.backend_kwargs[backend],
)
# adding a missing property should raise a warning
with self.assertWarns(UserWarning):
sw.plot_sorting_summary(
self.sorting_analyzer_sparse,
unit_table_properties=["missing_property"],
displayed_unit_properties=["missing_property"],
backend=backend,
**self.backend_kwargs[backend],
)
Expand Down Expand Up @@ -688,9 +701,9 @@ def test_plot_motion_info(self):
# mytest.test_plot_unit_presence()
# mytest.test_plot_peak_activity()
# mytest.test_plot_multicomparison()
# mytest.test_plot_sorting_summary()
mytest.test_plot_sorting_summary()
# mytest.test_plot_motion()
mytest.test_plot_motion_info()
plt.show()
# mytest.test_plot_motion_info()
# plt.show()

# TestWidgets.tearDownClass()
90 changes: 90 additions & 0 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,93 @@ def array_to_image(
output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape)

return output_image


def make_units_table_from_sorting(sorting, units_table=None):
"""
Make a DataFrame from sorting properties.
Only for properties with ndim=1
Parameters
----------
sorting : Sorting
The Sorting object
units_table : None | pd.DataFrame
Optionally a existing dataframe.
Returns
-------
units_table : pd.DataFrame
Table containing all columns.
"""

if units_table is None:
import pandas as pd

units_table = pd.DataFrame(index=sorting.unit_ids)

for col in sorting.get_property_keys():
values = sorting.get_property(col)
if values.dtype.kind in "iuUSfb" and values.ndim == 1:
units_table.loc[:, col] = values

return units_table


def make_units_table_from_analyzer(
analyzer,
extra_properties=None,
):
"""
Make a DataFrame by aggregating :
* quality metrics
* template metrics
* unit_position
* sorting properties
* extra columns
This used in sortingview and spikeinterface-gui to display the units table in a flexible way.
Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer object
extra_properties : None | dict
Extra columns given as dict.
Returns
-------
units_table : pd.DataFrame
Table containing all columns.
"""
import pandas as pd

all_df = []

if analyzer.get_extension("unit_locations") is not None:
locs = analyzer.get_extension("unit_locations").get_data()
df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids)
all_df.append(df)

if analyzer.get_extension("quality_metrics") is not None:
df = analyzer.get_extension("quality_metrics").get_data()
all_df.append(df)

if analyzer.get_extension("template_metrics") is not None:
df = analyzer.get_extension("template_metrics").get_data()
all_df.append(df)

if len(all_df) > 0:
units_table = pd.concat(all_df, axis=1)
else:
units_table = pd.DataFrame(index=analyzer.unit_ids)

make_units_table_from_sorting(analyzer.sorting, units_table=units_table)

if extra_properties is not None:
for col, values in extra_properties.items():
# the ndim = 1 is important because we need column only for the display in gui.
if values.dtype.kind in "iuUSfb" and values.ndim == 1:
units_table.loc[:, col] = values

return units_table
Loading

0 comments on commit f0cd3cb

Please sign in to comment.