diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index a113298851..8eada29b0e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -2,6 +2,8 @@ import numpy as np +import warnings + from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget @@ -14,6 +16,9 @@ from ..core import SortingAnalyzer +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] + + class SortingSummaryWidget(BaseWidget): """ Plots spike sorting summary. @@ -42,14 +47,24 @@ class SortingSummaryWidget(BaseWidget): label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) - unit_table_properties : list or None, default: None + displayed_unit_properties : list or None, default: None List of properties to be added to the unit table. These may be drawn from the sorting extractor, and, if available, - the quality_metrics and template_metrics extensions of the SortingAnalyzer. + the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer. See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - (sortingview backend) + extra_unit_properties : dict or None, default: None + A dict with extra units properties to display. + curation_dict : dict or None, default: None + When curation is True, optionaly the viewer can get a previous 'curation_dict' + to continue/check previous curations on this analyzer. + In this case label_definitions must be None beacuse it is already included in the curation_dict. + (spikeinterface_gui backend) + label_definitions : dict or None, default: None + When curation is True, optionaly the user can provide a label_definitions dict. + This replaces the label_choices in the curation_format. + (spikeinterface_gui backend) """ def __init__( @@ -60,11 +75,24 @@ def __init__( max_amplitudes_per_unit=None, min_similarity_for_correlograms=0.2, curation=False, - unit_table_properties=None, + displayed_unit_properties=None, + extra_unit_properties=None, label_choices=None, + curation_dict=None, + label_definitions=None, backend=None, + unit_table_properties=None, **backend_kwargs, ): + + if unit_table_properties is not None: + warnings.warn( + "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", + category=DeprecationWarning, + stacklevel=2, + ) + displayed_unit_properties = unit_table_properties + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions( sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] @@ -74,18 +102,29 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - plot_data = dict( + if curation_dict is not None and label_definitions is not None: + raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") + + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, - unit_table_properties=unit_table_properties, + displayed_unit_properties=displayed_unit_properties, + extra_unit_properties=extra_unit_properties, curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, + curation_dict=curation_dict, + label_definitions=label_definitions, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv @@ -156,7 +195,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores ) if dp.curation: @@ -190,9 +229,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui + from spikeinterface_gui import run_mainwindow - app = spikeinterface_gui.mkQApp() - win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"]) - win.show() - app.exec_() + run_mainwindow( + sorting_analyzer, + with_traces=True, + curation=data_plot["curation"], + curation_dict=data_plot["curation_dict"], + label_definitions=data_plot["label_definitions"], + extra_unit_properties=data_plot["extra_unit_properties"], + displayed_unit_properties=data_plot["displayed_unit_properties"], + ) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 80f58f5ad9..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,9 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), + quality_metrics=dict( + metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"] + ), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -531,18 +533,29 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary( + self.sorting_analyzer_dense, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) sw.plot_sorting_summary( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, + displayed_unit_properties=[], backend=backend, **self.backend_kwargs[backend], ) - # add unit_properties + # select unit_properties sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["firing_rate", "snr"], + displayed_unit_properties=["firing_rate", "snr"], backend=backend, **self.backend_kwargs[backend], ) @@ -550,7 +563,7 @@ def test_plot_sorting_summary(self): with self.assertWarns(UserWarning): sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["missing_property"], + displayed_unit_properties=["missing_property"], backend=backend, **self.backend_kwargs[backend], ) @@ -688,9 +701,9 @@ def test_plot_motion_info(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - # mytest.test_plot_sorting_summary() + mytest.test_plot_sorting_summary() # mytest.test_plot_motion() - mytest.test_plot_motion_info() - plt.show() + # mytest.test_plot_motion_info() + # plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ca09cc4d8f..75c6248f0f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -243,3 +243,93 @@ def array_to_image( output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape) return output_image + + +def make_units_table_from_sorting(sorting, units_table=None): + """ + Make a DataFrame from sorting properties. + Only for properties with ndim=1 + + Parameters + ---------- + sorting : Sorting + The Sorting object + units_table : None | pd.DataFrame + Optionally a existing dataframe. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + + if units_table is None: + import pandas as pd + + units_table = pd.DataFrame(index=sorting.unit_ids) + + for col in sorting.get_property_keys(): + values = sorting.get_property(col) + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table + + +def make_units_table_from_analyzer( + analyzer, + extra_properties=None, +): + """ + Make a DataFrame by aggregating : + * quality metrics + * template metrics + * unit_position + * sorting properties + * extra columns + + This used in sortingview and spikeinterface-gui to display the units table in a flexible way. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + extra_properties : None | dict + Extra columns given as dict. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + import pandas as pd + + all_df = [] + + if analyzer.get_extension("unit_locations") is not None: + locs = analyzer.get_extension("unit_locations").get_data() + df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) + all_df.append(df) + + if analyzer.get_extension("quality_metrics") is not None: + df = analyzer.get_extension("quality_metrics").get_data() + all_df.append(df) + + if analyzer.get_extension("template_metrics") is not None: + df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) + + if len(all_df) > 0: + units_table = pd.concat(all_df, axis=1) + else: + units_table = pd.DataFrame(index=analyzer.unit_ids) + + make_units_table_from_sorting(analyzer.sorting, units_table=units_table) + + if extra_properties is not None: + for col, values in extra_properties.items(): + # the ndim = 1 is important because we need column only for the display in gui. + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index a6cc562ba2..d594414287 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -1,10 +1,12 @@ from __future__ import annotations +from warnings import warn + import numpy as np from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json -from warnings import warn +from .utils import make_units_table_from_sorting, make_units_table_from_analyzer def make_serializable(*args): @@ -50,105 +52,55 @@ def handle_display_and_url(widget, view, **backend_kwargs): def generate_unit_table_view( sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, unit_properties: list[str] | None = None, - similarity_scores: npndarray | None = None, + similarity_scores: np.ndarray | None = None, ): import sortingview.views as vv if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): analyzer = sorting_or_sorting_analyzer + units_tables = make_units_table_from_analyzer(analyzer) sorting = analyzer.sorting else: sorting = sorting_or_sorting_analyzer - analyzer = None - - # Find available unit properties from all sources - sorting_props = list(sorting.get_property_keys()) - if analyzer is not None: - if analyzer.get_extension("quality_metrics") is not None: - qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) - qm_data = analyzer.get_extension("quality_metrics").get_data() - else: - qm_props = [] - if analyzer.get_extension("template_metrics") is not None: - tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) - tm_data = analyzer.get_extension("template_metrics").get_data() - else: - tm_props = [] - # Check for any overlaps and warn user if any - all_props = sorting_props + qm_props + tm_props - else: - all_props = sorting_props - qm_props = [] - tm_props = [] - qm_data = None - tm_data = None - - overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] - if len(overlap_props) > 0: - warn( - f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}" - ) - - # Get unit properties + units_tables = make_units_table_from_sorting(sorting) + # analyzer = None + if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] else: + # keep only selected columns + unit_properties = np.array(unit_properties) + keep = np.isin(unit_properties, units_tables.columns) + if sum(keep) < len(unit_properties): + warn(f"Some unit properties are not in the sorting: {unit_properties[~keep]}") + unit_properties = unit_properties[keep] + units_tables = units_tables.loc[:, unit_properties] + + dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"} + ut_columns = [] + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + txt_dtype = dtype_convertor[values.dtype.kind] + ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=txt_dtype)) + ut_rows = [] - values = {} - valid_unit_properties = [] - - # Create columns for each property - for prop_name in unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - else: - warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") - continue - - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - # Create rows for each unit - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - - # Check for NaN values and round floats - val0 = np.array(property_values[0]) - if val0.dtype.kind == "f": - if np.isnan(property_values[ui]): - continue - property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False) - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + for unit_index, unit_id in enumerate(sorting.unit_ids): + row_values = {} + for col in unit_properties: + values = units_tables[col].to_numpy() + if values.dtype.kind in dtype_convertor: + value = values[unit_index] + if values.dtype.kind == "f": + # Check for NaN values and round floats + if np.isnan(values[unit_index]): + continue + value = np.format_float_positional(value, precision=4, fractional=False) + row_values[col] = value + ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table