Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow quality and template metrics in sortingview's unit table #3299

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/spikeinterface/widgets/sorting_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ class SortingSummaryWidget(BaseWidget):
List of labels to be added to the curation table
(sortingview backend)
unit_table_properties : list or None, default: None
List of properties to be added to the unit table
List of properties to be added to the unit table.
These may be drawn from the sorting extractor, and, if available,
the quality_metrics and template_metrics extensions of the SortingAnalyzer.
See all properties available with sorting.get_property_keys(), and, if available,
analyzer.get_extension("quality_metrics").get_data().columns and
analyzer.get_extension("template_metrics").get_data().columns.
(sortingview backend)
"""

Expand Down Expand Up @@ -151,7 +156,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

# unit ids
v_units_table = generate_unit_table_view(
dp.sorting_analyzer.sorting, dp.unit_table_properties, similarity_scores=similarity_scores
dp.sorting_analyzer, dp.unit_table_properties, similarity_scores=similarity_scores
)

if dp.curation:
Expand Down
59 changes: 55 additions & 4 deletions src/spikeinterface/widgets/utils_sortingview.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from ..core.core_tools import check_json
from warnings import warn


def make_serializable(*args):
Expand Down Expand Up @@ -45,9 +46,33 @@ def handle_display_and_url(widget, view, **backend_kwargs):
return url


def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None):
def generate_unit_table_view(analyzer, unit_properties=None, similarity_scores=None):
import sortingview.views as vv

sorting = analyzer.sorting

# Find available unit properties from all sources
sorting_props = list(sorting.get_property_keys())
if analyzer.get_extension("quality_metrics") is not None:
qm_props = list(analyzer.get_extension("quality_metrics").get_data().columns)
qm_data = analyzer.get_extension("quality_metrics").get_data()
else:
qm_props = []
if analyzer.get_extension("template_metrics") is not None:
tm_props = list(analyzer.get_extension("template_metrics").get_data().columns)
tm_data = analyzer.get_extension("template_metrics").get_data()
else:
tm_props = []

# Check for any overlaps and warn user if any
all_props = sorting_props + qm_props + tm_props
overlap_props = [prop for prop in all_props if all_props.count(prop) > 1]
if len(overlap_props) > 0:
warn(
f"Warning: Overlapping properties found in sorting, quality_metrics, and template_metrics: {overlap_props}"
)

# Get unit properties
if unit_properties is None:
ut_columns = []
ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids]
Expand All @@ -56,8 +81,20 @@ def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=No
ut_rows = []
values = {}
valid_unit_properties = []

# Create columns for each property
for prop_name in unit_properties:
property_values = sorting.get_property(prop_name)

# Get property values from correct location
if prop_name in sorting_props:
property_values = sorting.get_property(prop_name)
elif prop_name in qm_props:
property_values = qm_data[prop_name].values
elif prop_name in tm_props:
property_values = tm_data[prop_name].values
else:
raise ValueError(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics")

# make dtype available
val0 = np.array(property_values[0])
if val0.dtype.kind in ("i", "u"):
Expand All @@ -74,14 +111,28 @@ def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=No
ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype))
valid_unit_properties.append(prop_name)

# Create rows for each unit
for ui, unit in enumerate(sorting.unit_ids):
for prop_name in valid_unit_properties:
property_values = sorting.get_property(prop_name)

# Get property values from correct location
if prop_name in sorting_props:
property_values = sorting.get_property(prop_name)
elif prop_name in qm_props:
property_values = qm_data[prop_name].values
elif prop_name in tm_props:
property_values = tm_data[prop_name].values
else:
raise ValueError(
f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics"
)

# Check for NaN values
val0 = np.array(property_values[0])
if val0.dtype.kind == "f":
if np.isnan(property_values[ui]):
continue
values[prop_name] = property_values[ui]
values[prop_name] = np.format_float_positional(property_values[ui], precision=4, fractional=False)
ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values)))

v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores)
Expand Down