From 7a5e75fc65008f4eafee48db8ebf69c043c5df65 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 8 Jan 2025 21:52:54 +0100 Subject: [PATCH 01/10] Start changes to improve sigui API --- src/spikeinterface/widgets/sorting_summary.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index a113298851..6b8e9b7d44 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -50,8 +50,16 @@ class SortingSummaryWidget(BaseWidget): analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. (sortingview backend) + curation_dict : dict or None + When curation is True, optionaly the viewer can get a previous 'curation_dict' + to continue/check previous curations on this analyzer. + In this case label_definitions must be None beacuse it is already included in the curation_dict. + (spikeinterface_gui backend) + label_definitions : dict or None + When curation is True, optionaly the user can provide a label_definitions dict. + This replaces the label_choices in the curation_format. + (spikeinterface_gui backend) """ - def __init__( self, sorting_analyzer: SortingAnalyzer, @@ -62,6 +70,8 @@ def __init__( curation=False, unit_table_properties=None, label_choices=None, + curation_dict=None, + label_definitions=None, backend=None, **backend_kwargs, ): @@ -74,6 +84,9 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() + if curation_dict is not None and label_definitions is not None: + raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") + plot_data = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, @@ -83,6 +96,8 @@ def __init__( curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, + curation_dict=curation_dict, + label_definitions=label_definitions, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -193,6 +208,12 @@ def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): import spikeinterface_gui app = spikeinterface_gui.mkQApp() - win = spikeinterface_gui.MainWindow(sorting_analyzer, curation=data_plot["curation"]) + win = spikeinterface_gui.MainWindow( + sorting_analyzer, + curation=data_plot["curation"] + curation_data=data_plot["curation_dict"], + label_definitions=data_plot["label_definitions"], + more_units_properties=data_plot["unit_table_properties"], + ) win.show() app.exec_() From 89c36ac603a50b67ebdc9a466b95f6d644346c88 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 14 Jan 2025 16:36:48 +0100 Subject: [PATCH 02/10] Some change in plot_unit_summary for sigui --- src/spikeinterface/widgets/sorting_summary.py | 40 ++++-- .../widgets/tests/test_widgets.py | 6 +- src/spikeinterface/widgets/utils.py | 74 +++++++++++ .../widgets/utils_sortingview.py | 115 +++++------------- 4 files changed, 135 insertions(+), 100 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 6b8e9b7d44..46796e3be4 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -2,6 +2,8 @@ import numpy as np +import warnings + from .base import BaseWidget, to_attr from .amplitudes import AmplitudesWidget @@ -50,6 +52,8 @@ class SortingSummaryWidget(BaseWidget): analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. (sortingview backend) + extra_units_properties : None dict, default: None + A dict with extra units properties to display. curation_dict : dict or None When curation is True, optionaly the viewer can get a previous 'curation_dict' to continue/check previous curations on this analyzer. @@ -68,13 +72,21 @@ def __init__( max_amplitudes_per_unit=None, min_similarity_for_correlograms=0.2, curation=False, - unit_table_properties=None, + displayed_units_properties=None, + extra_units_properties=None, label_choices=None, curation_dict=None, label_definitions=None, backend=None, + unit_table_properties=None, **backend_kwargs, ): + + if unit_table_properties is not None: + warnings.warn("plot_sorting_summary() : unit_table_properties is deprecated, use displayed_units_properties instead") + displayed_units_properties = unit_table_properties + + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions( sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] @@ -87,12 +99,13 @@ def __init__( if curation_dict is not None and label_definitions is not None: raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") - plot_data = dict( + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, - unit_table_properties=unit_table_properties, + displayed_units_properties=displayed_units_properties, + extra_units_properties=extra_units_properties, curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, @@ -100,7 +113,7 @@ def __init__( label_definitions=label_definitions, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv @@ -171,7 +184,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.displayed_units_properties, similarity_scores=similarity_scores ) if dp.curation: @@ -205,15 +218,16 @@ def plot_sortingview(self, data_plot, **backend_kwargs): def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): sorting_analyzer = data_plot["sorting_analyzer"] - import spikeinterface_gui + from spikeinterface_gui import run_mainwindow + - app = spikeinterface_gui.mkQApp() - win = spikeinterface_gui.MainWindow( + run_mainwindow( sorting_analyzer, - curation=data_plot["curation"] - curation_data=data_plot["curation_dict"], + with_traces=True, + curation=data_plot["curation"], + curation_dict=data_plot["curation_dict"], label_definitions=data_plot["label_definitions"], - more_units_properties=data_plot["unit_table_properties"], + extra_units_properties=data_plot["extra_units_properties"], + displayed_units_properties=data_plot["displayed_units_properties"], ) - win.show() - app.exec_() + diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 80f58f5ad9..b723a7ca9f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -688,9 +688,9 @@ def test_plot_motion_info(self): # mytest.test_plot_unit_presence() # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() - # mytest.test_plot_sorting_summary() + mytest.test_plot_sorting_summary() # mytest.test_plot_motion() - mytest.test_plot_motion_info() - plt.show() + # mytest.test_plot_motion_info() + # plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ca09cc4d8f..d7789f29ed 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -243,3 +243,77 @@ def array_to_image( output_image = np.frombuffer(image.tobytes(), dtype=np.uint8).reshape(output_image.shape) return output_image + + + +def make_units_table_from_sorting(sorting, units_table=None): + + if units_table is None: + import pandas as pd + units_table = pd.DataFrame(index=sorting.unit_ids) + + for col in sorting.get_property_keys(): + values = sorting.get_property(col) + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + print(col, values, sorting.unit_ids) + print(col, len(values), len(sorting.unit_ids)) + units_table.loc[:, col] = values + + return units_table + +def make_units_table_from_analyzer( + analyzer, + extra_properties=None, + ): + """ + Make a DataFrame by aggregating : + * quality metrics + * template metrics + * unit_position + * sorting properties + * extra columns + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + extra_properties : None | dict + Extra columns given as dict. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ + import pandas as pd + all_df = [] + + if analyzer.get_extension("unit_locations") is not None: + locs = analyzer.get_extension("unit_locations").get_data() + df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) + print(df.index, df.index.dtype) + all_df.append(df) + + if analyzer.get_extension("quality_metrics") is not None: + df = analyzer.get_extension("quality_metrics").get_data() + print(df.index, df.index.dtype) + all_df.append(df) + + if analyzer.get_extension("template_metrics") is not None: + all_df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) + + if len(all_df) > 0: + units_table = pd.concat(all_df, axis=1) + else: + units_table = pd.DataFrame(index=analyzer.unit_ids) + + print(units_table) + make_units_table_from_sorting(analyzer.sorting, units_table=units_table) + + if extra_properties is not None: + for col, values in extra_properties.items(): + if values.dtype.kind in "iuUSfb" and values.ndim == 1: + units_table.loc[:, col] = values + + return units_table \ No newline at end of file diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index a6cc562ba2..215c1eaf32 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -1,10 +1,13 @@ from __future__ import annotations +from warnings import warn + import numpy as np from ..core import SortingAnalyzer, BaseSorting from ..core.core_tools import check_json -from warnings import warn +from .utils import make_units_table_from_sorting, make_units_table_from_analyzer + def make_serializable(*args): @@ -50,105 +53,49 @@ def handle_display_and_url(widget, view, **backend_kwargs): def generate_unit_table_view( sorting_or_sorting_analyzer: SortingAnalyzer | BaseSorting, unit_properties: list[str] | None = None, - similarity_scores: npndarray | None = None, + similarity_scores: np.ndarray | None = None, ): import sortingview.views as vv if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): analyzer = sorting_or_sorting_analyzer + units_tables = make_units_table_from_sorting(analyzer) sorting = analyzer.sorting else: sorting = sorting_or_sorting_analyzer - analyzer = None - - # Find available unit properties from all sources - sorting_props = list(sorting.get_property_keys()) - if analyzer is not None: - if analyzer.get_extension("quality_metrics") is not None: - qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns) - qm_data = analyzer.get_extension("quality_metrics").get_data() - else: - qm_props = [] - if analyzer.get_extension("template_metrics") is not None: - tm_props = list(analyzer.get_extension("template_metrics").get_data().columns) - tm_data = analyzer.get_extension("template_metrics").get_data() - else: - tm_props = [] - # Check for any overlaps and warn user if any - all_props = sorting_props + qm_props + tm_props - else: - all_props = sorting_props - qm_props = [] - tm_props = [] - qm_data = None - tm_data = None - - overlap_props = [prop for prop in all_props if all_props.count(prop) > 1] - if len(overlap_props) > 0: - warn( - f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}" - ) - - # Get unit properties + units_tables = make_units_table_from_analyzer(sorting) + # analyzer = None + if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] else: + # keep only selected columns + unit_properties = np.array(unit_properties) + keep = np.isin(unit_properties, units_tables.columns) + unit_properties = unit_properties[keep] + units_tables = units_tables.loc[:, unit_properties] + ut_columns = [] - ut_rows = [] - values = {} - valid_unit_properties = [] - - # Create columns for each property - for prop_name in unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() - else: - warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") - continue - - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - warn(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - # Create rows for each unit - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - - # Get property values from correct location - if prop_name in sorting_props: - property_values = sorting.get_property(prop_name) - elif prop_name in qm_props: - property_values = qm_data[prop_name].to_numpy() - elif prop_name in tm_props: - property_values = tm_data[prop_name].to_numpy() + for col in units_tables.columns: + values = units_tables[col].to_numpy() + ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=values.dtype)) + ut_rows = [] + for unit_index, unit_id in enumerate(sorting.unit_ids): + row_values = {} + for col in units_tables.columns: + values = units_tables[col].to_numpy() + value = values[unit_index] # Check for NaN values and round floats - val0 = np.array(property_values[0]) - if val0.dtype.kind == "f": - if np.isnan(property_values[ui]): + if values.dtype.kind == "f": + if np.isnan(values[unit_index]): continue - property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False) - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + value = np.format_float_positional(value, precision=4, fractional=False) + row_values[col] = value + ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) + v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table From 3c1e19524c6c04284332a94fad81bc548b20659d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 15 Jan 2025 09:35:09 +0100 Subject: [PATCH 03/10] clean --- src/spikeinterface/widgets/sorting_summary.py | 5 ++-- src/spikeinterface/widgets/utils.py | 26 ++++++++++++++----- .../widgets/utils_sortingview.py | 22 +++++++++------- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 46796e3be4..0b127abae1 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -44,14 +44,13 @@ class SortingSummaryWidget(BaseWidget): label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) - unit_table_properties : list or None, default: None + displayed_units_properties : list or None, default: None List of properties to be added to the unit table. These may be drawn from the sorting extractor, and, if available, - the quality_metrics and template_metrics extensions of the SortingAnalyzer. + the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer. See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - (sortingview backend) extra_units_properties : None dict, default: None A dict with extra units properties to display. curation_dict : dict or None diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index d7789f29ed..89025dea31 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -247,6 +247,22 @@ def array_to_image( def make_units_table_from_sorting(sorting, units_table=None): + """ + Make a DataFrame from sorting properties. + Only for properties with ndim=1 + + Parameters + ---------- + sorting : Sorting + The Sorting object + units_table : None | pd.DataFrame + Optionally a existing dataframe. + + Returns + ------- + units_table : pd.DataFrame + Table containing all columns. + """ if units_table is None: import pandas as pd @@ -255,8 +271,6 @@ def make_units_table_from_sorting(sorting, units_table=None): for col in sorting.get_property_keys(): values = sorting.get_property(col) if values.dtype.kind in "iuUSfb" and values.ndim == 1: - print(col, values, sorting.unit_ids) - print(col, len(values), len(sorting.unit_ids)) units_table.loc[:, col] = values return units_table @@ -273,6 +287,8 @@ def make_units_table_from_analyzer( * sorting properties * extra columns + This used in sortingview and spikeinterface-gui to display the units table in a flexible way. + Parameters ---------- sorting_analyzer : SortingAnalyzer @@ -291,12 +307,10 @@ def make_units_table_from_analyzer( if analyzer.get_extension("unit_locations") is not None: locs = analyzer.get_extension("unit_locations").get_data() df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) - print(df.index, df.index.dtype) all_df.append(df) if analyzer.get_extension("quality_metrics") is not None: df = analyzer.get_extension("quality_metrics").get_data() - print(df.index, df.index.dtype) all_df.append(df) if analyzer.get_extension("template_metrics") is not None: @@ -308,12 +322,12 @@ def make_units_table_from_analyzer( else: units_table = pd.DataFrame(index=analyzer.unit_ids) - print(units_table) make_units_table_from_sorting(analyzer.sorting, units_table=units_table) if extra_properties is not None: for col, values in extra_properties.items(): + # the ndim = 1 is important because we need column only for the display in gui. if values.dtype.kind in "iuUSfb" and values.ndim == 1: units_table.loc[:, col] = values - return units_table \ No newline at end of file + return units_table diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 215c1eaf32..f6eb8ea529 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -76,26 +76,30 @@ def generate_unit_table_view( unit_properties = unit_properties[keep] units_tables = units_tables.loc[:, unit_properties] + dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"} + ut_columns = [] for col in units_tables.columns: values = units_tables[col].to_numpy() - ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=values.dtype)) + if values.dtype.kind in dtype_convertor: + txt_dtype = dtype_convertor[values.dtype.kind] + ut_columns.append(vv.UnitsTableColumn(key=col, label=col, dtype=txt_dtype)) ut_rows = [] for unit_index, unit_id in enumerate(sorting.unit_ids): row_values = {} for col in units_tables.columns: values = units_tables[col].to_numpy() - value = values[unit_index] - # Check for NaN values and round floats - if values.dtype.kind == "f": - if np.isnan(values[unit_index]): - continue - value = np.format_float_positional(value, precision=4, fractional=False) - row_values[col] = value + if values.dtype.kind in dtype_convertor: + value = values[unit_index] + if values.dtype.kind == "f": + # Check for NaN values and round floats + if np.isnan(values[unit_index]): + continue + value = np.format_float_positional(value, precision=4, fractional=False) + row_values[col] = value ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) - v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) return v_units_table From 941a053f45fd7b89691c004e74fa396fb29ca6f6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 15 Jan 2025 13:20:06 +0100 Subject: [PATCH 04/10] default displayed props --- src/spikeinterface/widgets/sorting_summary.py | 33 ++++++++++++------- .../widgets/utils_sortingview.py | 12 ++++--- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 0b127abae1..480fb76ceb 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -16,6 +16,8 @@ from ..core import SortingAnalyzer +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude", "snr", "rp_violation"] + class SortingSummaryWidget(BaseWidget): """ Plots spike sorting summary. @@ -44,14 +46,14 @@ class SortingSummaryWidget(BaseWidget): label_choices : list or None, default: None List of labels to be added to the curation table (sortingview backend) - displayed_units_properties : list or None, default: None + displayed_unit_properties : list or None, default: None List of properties to be added to the unit table. These may be drawn from the sorting extractor, and, if available, the quality_metrics/template_metrics/unit_locations extensions of the SortingAnalyzer. See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - extra_units_properties : None dict, default: None + extra_unit_properties : None dict, default: None A dict with extra units properties to display. curation_dict : dict or None When curation is True, optionaly the viewer can get a previous 'curation_dict' @@ -71,8 +73,8 @@ def __init__( max_amplitudes_per_unit=None, min_similarity_for_correlograms=0.2, curation=False, - displayed_units_properties=None, - extra_units_properties=None, + displayed_unit_properties=None, + extra_unit_properties=None, label_choices=None, curation_dict=None, label_definitions=None, @@ -82,8 +84,12 @@ def __init__( ): if unit_table_properties is not None: - warnings.warn("plot_sorting_summary() : unit_table_properties is deprecated, use displayed_units_properties instead") - displayed_units_properties = unit_table_properties + warnings.warn( + "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", + category=DeprecationWarning, + stacklevel=2, + ) + displayed_unit_properties = unit_table_properties sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -98,13 +104,18 @@ def __init__( if curation_dict is not None and label_definitions is not None: raise ValueError("curation_dict and label_definitions are mutualy exclusive, they cannot be not None both") + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + displayed_unit_properties += list(extra_unit_properties.keys()) + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, sparsity=sparsity, min_similarity_for_correlograms=min_similarity_for_correlograms, - displayed_units_properties=displayed_units_properties, - extra_units_properties=extra_units_properties, + displayed_unit_properties=displayed_unit_properties, + extra_unit_properties=extra_unit_properties, curation=curation, label_choices=label_choices, max_amplitudes_per_unit=max_amplitudes_per_unit, @@ -183,7 +194,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # unit ids v_units_table = generate_unit_table_view( - dp.sorting_analyzer, dp.displayed_units_properties, similarity_scores=similarity_scores + dp.sorting_analyzer, dp.displayed_unit_properties, similarity_scores=similarity_scores ) if dp.curation: @@ -226,7 +237,7 @@ def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): curation=data_plot["curation"], curation_dict=data_plot["curation_dict"], label_definitions=data_plot["label_definitions"], - extra_units_properties=data_plot["extra_units_properties"], - displayed_units_properties=data_plot["displayed_units_properties"], + extra_unit_properties=data_plot["extra_unit_properties"], + displayed_unit_properties=data_plot["displayed_unit_properties"], ) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index f6eb8ea529..2a7a8d5ec4 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -59,11 +59,11 @@ def generate_unit_table_view( if isinstance(sorting_or_sorting_analyzer, SortingAnalyzer): analyzer = sorting_or_sorting_analyzer - units_tables = make_units_table_from_sorting(analyzer) + units_tables = make_units_table_from_analyzer(analyzer) sorting = analyzer.sorting else: sorting = sorting_or_sorting_analyzer - units_tables = make_units_table_from_analyzer(sorting) + units_tables = make_units_table_from_sorting(sorting) # analyzer = None if unit_properties is None: @@ -79,7 +79,9 @@ def generate_unit_table_view( dtype_convertor = {"i": "int", "u": "int", "f": "float", "U": "str", "S": "str", "b": "bool"} ut_columns = [] - for col in units_tables.columns: + for col in unit_properties: + if col not in units_tables.columns: + continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: txt_dtype = dtype_convertor[values.dtype.kind] @@ -88,7 +90,9 @@ def generate_unit_table_view( ut_rows = [] for unit_index, unit_id in enumerate(sorting.unit_ids): row_values = {} - for col in units_tables.columns: + for col in unit_properties: + if col not in units_tables.columns: + continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: value = values[unit_index] From 9b8f3d230358de623154630014dcc19db5075165 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 17 Jan 2025 12:12:14 +0100 Subject: [PATCH 05/10] yep Co-authored-by: Alessio Buccino --- src/spikeinterface/widgets/sorting_summary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 480fb76ceb..76d418b64d 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -55,12 +55,12 @@ class SortingSummaryWidget(BaseWidget): analyzer.get_extension("template_metrics").get_data().columns. extra_unit_properties : None dict, default: None A dict with extra units properties to display. - curation_dict : dict or None + curation_dict : dict or None, default: None When curation is True, optionaly the viewer can get a previous 'curation_dict' to continue/check previous curations on this analyzer. In this case label_definitions must be None beacuse it is already included in the curation_dict. (spikeinterface_gui backend) - label_definitions : dict or None + label_definitions : dict or None, default: None When curation is True, optionaly the user can provide a label_definitions dict. This replaces the label_choices in the curation_format. (spikeinterface_gui backend) From e8b4fe12e1674bd4e3d95734b33a94cd011a0fef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 16:00:59 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/sorting_summary.py | 11 +++++------ src/spikeinterface/widgets/utils.py | 10 ++++++---- src/spikeinterface/widgets/utils_sortingview.py | 5 ++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 76d418b64d..2271e5a4cb 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -18,6 +18,7 @@ _default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude", "snr", "rp_violation"] + class SortingSummaryWidget(BaseWidget): """ Plots spike sorting summary. @@ -65,6 +66,7 @@ class SortingSummaryWidget(BaseWidget): This replaces the label_choices in the curation_format. (spikeinterface_gui backend) """ + def __init__( self, sorting_analyzer: SortingAnalyzer, @@ -82,16 +84,15 @@ def __init__( unit_table_properties=None, **backend_kwargs, ): - + if unit_table_properties is not None: warnings.warn( "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", category=DeprecationWarning, stacklevel=2, - ) + ) displayed_unit_properties = unit_table_properties - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions( sorting_analyzer, ["correlograms", "spike_amplitudes", "unit_locations", "template_similarity"] @@ -108,7 +109,7 @@ def __init__( displayed_unit_properties = list(_default_displayed_unit_properties) if extra_unit_properties is not None: displayed_unit_properties += list(extra_unit_properties.keys()) - + data_plot = dict( sorting_analyzer=sorting_analyzer, unit_ids=unit_ids, @@ -230,7 +231,6 @@ def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): from spikeinterface_gui import run_mainwindow - run_mainwindow( sorting_analyzer, with_traces=True, @@ -240,4 +240,3 @@ def plot_spikeinterface_gui(self, data_plot, **backend_kwargs): extra_unit_properties=data_plot["extra_unit_properties"], displayed_unit_properties=data_plot["displayed_unit_properties"], ) - diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 89025dea31..7d5cf98c01 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -245,7 +245,6 @@ def array_to_image( return output_image - def make_units_table_from_sorting(sorting, units_table=None): """ Make a DataFrame from sorting properties. @@ -266,6 +265,7 @@ def make_units_table_from_sorting(sorting, units_table=None): if units_table is None: import pandas as pd + units_table = pd.DataFrame(index=sorting.unit_ids) for col in sorting.get_property_keys(): @@ -275,10 +275,11 @@ def make_units_table_from_sorting(sorting, units_table=None): return units_table + def make_units_table_from_analyzer( - analyzer, - extra_properties=None, - ): + analyzer, + extra_properties=None, +): """ Make a DataFrame by aggregating : * quality metrics @@ -302,6 +303,7 @@ def make_units_table_from_analyzer( Table containing all columns. """ import pandas as pd + all_df = [] if analyzer.get_extension("unit_locations") is not None: diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 2a7a8d5ec4..451b1d145e 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -9,7 +9,6 @@ from .utils import make_units_table_from_sorting, make_units_table_from_analyzer - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) @@ -65,7 +64,7 @@ def generate_unit_table_view( sorting = sorting_or_sorting_analyzer units_tables = make_units_table_from_sorting(sorting) # analyzer = None - + if unit_properties is None: ut_columns = [] ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] @@ -105,5 +104,5 @@ def generate_unit_table_view( ut_rows.append(vv.UnitsTableRow(unit_id=unit_id, values=check_json(row_values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) - + return v_units_table From 7e49ab963d2ed108d61688e4f9e16ab46f938768 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 17 Jan 2025 17:12:43 +0100 Subject: [PATCH 07/10] Update src/spikeinterface/widgets/sorting_summary.py --- src/spikeinterface/widgets/sorting_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 2271e5a4cb..8587830862 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -54,7 +54,7 @@ class SortingSummaryWidget(BaseWidget): See all properties available with sorting.get_property_keys(), and, if available, analyzer.get_extension("quality_metrics").get_data().columns and analyzer.get_extension("template_metrics").get_data().columns. - extra_unit_properties : None dict, default: None + extra_unit_properties : dict or None, default: None A dict with extra units properties to display. curation_dict : dict or None, default: None When curation is True, optionaly the viewer can get a previous 'curation_dict' From 4831ad0851e71910b3fce9e2630b6c7cc72190d3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 16:16:33 +0100 Subject: [PATCH 08/10] Use pd.concat instead of df.append --- src/spikeinterface/widgets/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 7d5cf98c01..ae1dce8571 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -309,15 +309,15 @@ def make_units_table_from_analyzer( if analyzer.get_extension("unit_locations") is not None: locs = analyzer.get_extension("unit_locations").get_data() df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) - all_df.append(df) + all_df = pd.concat([all_df, df]) if analyzer.get_extension("quality_metrics") is not None: df = analyzer.get_extension("quality_metrics").get_data() - all_df.append(df) + all_df = pd.concat([all_df, df]) if analyzer.get_extension("template_metrics") is not None: all_df = analyzer.get_extension("template_metrics").get_data() - all_df.append(df) + all_df = pd.concat([all_df, df]) if len(all_df) > 0: units_table = pd.concat(all_df, axis=1) From 2ba92b7c7987cb12142556c0cbeb8f34bccb6cb3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 16:26:19 +0100 Subject: [PATCH 09/10] Fix template metrics sv --- src/spikeinterface/widgets/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ae1dce8571..75c6248f0f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -309,15 +309,15 @@ def make_units_table_from_analyzer( if analyzer.get_extension("unit_locations") is not None: locs = analyzer.get_extension("unit_locations").get_data() df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) - all_df = pd.concat([all_df, df]) + all_df.append(df) if analyzer.get_extension("quality_metrics") is not None: df = analyzer.get_extension("quality_metrics").get_data() - all_df = pd.concat([all_df, df]) + all_df.append(df) if analyzer.get_extension("template_metrics") is not None: - all_df = analyzer.get_extension("template_metrics").get_data() - all_df = pd.concat([all_df, df]) + df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) if len(all_df) > 0: units_table = pd.concat(all_df, axis=1) From 6cf86f958331dc7965206568b0cf4e9b0772149d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 17:31:54 +0100 Subject: [PATCH 10/10] Fix tests and warning --- src/spikeinterface/widgets/sorting_summary.py | 2 +- .../widgets/tests/test_widgets.py | 25 ++++++++++++++----- .../widgets/utils_sortingview.py | 6 ++--- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 8587830862..8eada29b0e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -16,7 +16,7 @@ from ..core import SortingAnalyzer -_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude", "snr", "rp_violation"] +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] class SortingSummaryWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index b723a7ca9f..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,9 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), + quality_metrics=dict( + metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"] + ), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -531,18 +533,29 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary( + self.sorting_analyzer_dense, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) sw.plot_sorting_summary( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, + displayed_unit_properties=[], backend=backend, **self.backend_kwargs[backend], ) - # add unit_properties + # select unit_properties sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["firing_rate", "snr"], + displayed_unit_properties=["firing_rate", "snr"], backend=backend, **self.backend_kwargs[backend], ) @@ -550,7 +563,7 @@ def test_plot_sorting_summary(self): with self.assertWarns(UserWarning): sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["missing_property"], + displayed_unit_properties=["missing_property"], backend=backend, **self.backend_kwargs[backend], ) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 451b1d145e..d594414287 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -72,6 +72,8 @@ def generate_unit_table_view( # keep only selected columns unit_properties = np.array(unit_properties) keep = np.isin(unit_properties, units_tables.columns) + if sum(keep) < len(unit_properties): + warn(f"Some unit properties are not in the sorting: {unit_properties[~keep]}") unit_properties = unit_properties[keep] units_tables = units_tables.loc[:, unit_properties] @@ -79,8 +81,6 @@ def generate_unit_table_view( ut_columns = [] for col in unit_properties: - if col not in units_tables.columns: - continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: txt_dtype = dtype_convertor[values.dtype.kind] @@ -90,8 +90,6 @@ def generate_unit_table_view( for unit_index, unit_id in enumerate(sorting.unit_ids): row_values = {} for col in unit_properties: - if col not in units_tables.columns: - continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: value = values[unit_index]