From 4831ad0851e71910b3fce9e2630b6c7cc72190d3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 16:16:33 +0100 Subject: [PATCH 1/3] Use pd.concat instead of df.append --- src/spikeinterface/widgets/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 7d5cf98c01..ae1dce8571 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -309,15 +309,15 @@ def make_units_table_from_analyzer( if analyzer.get_extension("unit_locations") is not None: locs = analyzer.get_extension("unit_locations").get_data() df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) - all_df.append(df) + all_df = pd.concat([all_df, df]) if analyzer.get_extension("quality_metrics") is not None: df = analyzer.get_extension("quality_metrics").get_data() - all_df.append(df) + all_df = pd.concat([all_df, df]) if analyzer.get_extension("template_metrics") is not None: all_df = analyzer.get_extension("template_metrics").get_data() - all_df.append(df) + all_df = pd.concat([all_df, df]) if len(all_df) > 0: units_table = pd.concat(all_df, axis=1) From 2ba92b7c7987cb12142556c0cbeb8f34bccb6cb3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 16:26:19 +0100 Subject: [PATCH 2/3] Fix template metrics sv --- src/spikeinterface/widgets/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index ae1dce8571..75c6248f0f 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -309,15 +309,15 @@ def make_units_table_from_analyzer( if analyzer.get_extension("unit_locations") is not None: locs = analyzer.get_extension("unit_locations").get_data() df = pd.DataFrame(locs[:, :2], columns=["x", "y"], index=analyzer.unit_ids) - all_df = pd.concat([all_df, df]) + all_df.append(df) if analyzer.get_extension("quality_metrics") is not None: df = analyzer.get_extension("quality_metrics").get_data() - all_df = pd.concat([all_df, df]) + all_df.append(df) if analyzer.get_extension("template_metrics") is not None: - all_df = analyzer.get_extension("template_metrics").get_data() - all_df = pd.concat([all_df, df]) + df = analyzer.get_extension("template_metrics").get_data() + all_df.append(df) if len(all_df) > 0: units_table = pd.concat(all_df, axis=1) From 6cf86f958331dc7965206568b0cf4e9b0772149d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 18 Jan 2025 17:31:54 +0100 Subject: [PATCH 3/3] Fix tests and warning --- src/spikeinterface/widgets/sorting_summary.py | 2 +- .../widgets/tests/test_widgets.py | 25 ++++++++++++++----- .../widgets/utils_sortingview.py | 6 ++--- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 8587830862..8eada29b0e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -16,7 +16,7 @@ from ..core import SortingAnalyzer -_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude", "snr", "rp_violation"] +_default_displayed_unit_properties = ["firing_rate", "num_spikes", "x", "y", "amplitude_median", "snr", "rp_violation"] class SortingSummaryWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index b723a7ca9f..d5ffec6dba 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,9 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), + quality_metrics=dict( + metric_names=["snr", "isi_violation", "num_spikes", "firing_rate", "amplitude_cutoff"] + ), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), @@ -531,18 +533,29 @@ def test_plot_sorting_summary(self): possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - sw.plot_sorting_summary(self.sorting_analyzer_dense, backend=backend, **self.backend_kwargs[backend]) - sw.plot_sorting_summary(self.sorting_analyzer_sparse, backend=backend, **self.backend_kwargs[backend]) + sw.plot_sorting_summary( + self.sorting_analyzer_dense, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) + sw.plot_sorting_summary( + self.sorting_analyzer_sparse, + displayed_unit_properties=[], + backend=backend, + **self.backend_kwargs[backend], + ) sw.plot_sorting_summary( self.sorting_analyzer_sparse, sparsity=self.sparsity_strict, + displayed_unit_properties=[], backend=backend, **self.backend_kwargs[backend], ) - # add unit_properties + # select unit_properties sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["firing_rate", "snr"], + displayed_unit_properties=["firing_rate", "snr"], backend=backend, **self.backend_kwargs[backend], ) @@ -550,7 +563,7 @@ def test_plot_sorting_summary(self): with self.assertWarns(UserWarning): sw.plot_sorting_summary( self.sorting_analyzer_sparse, - unit_table_properties=["missing_property"], + displayed_unit_properties=["missing_property"], backend=backend, **self.backend_kwargs[backend], ) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 451b1d145e..d594414287 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -72,6 +72,8 @@ def generate_unit_table_view( # keep only selected columns unit_properties = np.array(unit_properties) keep = np.isin(unit_properties, units_tables.columns) + if sum(keep) < len(unit_properties): + warn(f"Some unit properties are not in the sorting: {unit_properties[~keep]}") unit_properties = unit_properties[keep] units_tables = units_tables.loc[:, unit_properties] @@ -79,8 +81,6 @@ def generate_unit_table_view( ut_columns = [] for col in unit_properties: - if col not in units_tables.columns: - continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: txt_dtype = dtype_convertor[values.dtype.kind] @@ -90,8 +90,6 @@ def generate_unit_table_view( for unit_index, unit_id in enumerate(sorting.unit_ids): row_values = {} for col in unit_properties: - if col not in units_tables.columns: - continue values = units_tables[col].to_numpy() if values.dtype.kind in dtype_convertor: value = values[unit_index]