Fix dtype of quality metrics before and after merging #3497

Merged Jan 10, 2025 · 23 commits

Changes from all commits
20 changes: 19 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -17,6 +17,7 @@
    _misc_metric_name_to_func,
    _possible_pc_metric_names,
    qm_compute_name_to_column_names,
    column_name_to_column_dtype,
)
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params
@@ -140,13 +141,20 @@ def _merge_extension_data(
        all_unit_ids = new_sorting_analyzer.unit_ids
        not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

        # this creates a new metrics DataFrame, but every column will have object
        # dtype, so we need to fix the dtypes after computing the metrics
        metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

        metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
        metrics.loc[new_unit_ids, :] = self._compute_metrics(
            new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
        )

        # the dtypes have to be fixed after everything is computed because the NaNs
        # leave object columns; iterate through the columns and convert each back to
        # the dtype of the original quality-metrics dataframe
        for column in old_metrics.columns:
            metrics[column] = metrics[column].astype(old_metrics[column].dtype)

        new_data = dict(metrics=metrics)
        return new_data
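The comments in this hunk describe the core pandas behavior being worked around. A minimal standalone sketch of that behavior (the unit ids and metric values below are illustrative, not taken from the PR): a DataFrame built from only an index and columns starts with object dtype in every column, and row-wise .loc assignment does not restore the source dtypes, which is why the explicit astype loop is needed.

import numpy as np
import pandas as pd

# an "old" metrics table with well-defined dtypes
old_metrics = pd.DataFrame(
    {"num_spikes": np.array([120, 80], dtype=int), "snr": np.array([4.2, 3.1], dtype=float)},
    index=["unit0", "unit1"],
)

# building a frame from index/columns alone gives object dtype everywhere
merged = pd.DataFrame(index=["unit0", "merged0"], columns=old_metrics.columns)
print(merged.dtypes)  # num_spikes: object, snr: object

# copying rows in with .loc keeps the object dtype
merged.loc["unit0", :] = old_metrics.loc["unit0", :]
merged.loc["merged0", :] = [200, 4.0]
print(merged.dtypes)  # still object

# the fix: cast each column back to the dtype of the original frame
for column in old_metrics.columns:
    merged[column] = merged[column].astype(old_metrics[column].dtype)
print(merged.dtypes)  # num_spikes: int64, snr: float64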

@@ -229,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
        # add NaN for empty units
        if len(empty_unit_ids) > 0:
            metrics.loc[empty_unit_ids] = np.nan
            # num_spikes is an int and should be 0 for empty units, not NaN
            if "num_spikes" in metrics.columns:
                metrics.loc[empty_unit_ids, ["num_spikes"]] = 0

        # use convert_dtypes to convert the columns to the most appropriate dtype
        # and avoid object columns (in case of NaN values)
        metrics = metrics.convert_dtypes()

        # convert_dtypes sometimes infers the wrong type; the intended dtype for
        # each column is recorded in the column_name_to_column_dtype dictionary
        for column in metrics.columns:
            if column in column_name_to_column_dtype:
                metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])

        return metrics

    def _run(self, verbose=False, **job_kwargs):
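The explicit cast loop at the end of _compute_metrics exists because convert_dtypes can infer a misleading type. A small sketch of the failure mode, with illustrative values: a float column whose non-missing entries happen to be whole numbers is inferred as a nullable integer (Int64) rather than float, and casting with the intended dtype restores float64 while preserving the NaN.

import numpy as np
import pandas as pd

# a conceptually float metric whose surviving values happen to be whole numbers
s = pd.Series([2.0, 4.0, np.nan], name="sync_spike_2")

converted = s.convert_dtypes()
print(converted.dtype)  # Int64 -- inferred as nullable integer, not float

# casting with the intended dtype restores float64 (the NaN survives)
fixed = converted.astype(float)
print(fixed.dtype)  # float64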
41 changes: 40 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_list.py

@@ -66,7 +66,11 @@
"amplitude_cutoff": ["amplitude_cutoff"],
"amplitude_median": ["amplitude_median"],
"amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
"synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
"synchrony": [
"sync_spike_2",
"sync_spike_4",
"sync_spike_8",
],
"firing_range": ["firing_range"],
"drift": ["drift_ptp", "drift_std", "drift_mad"],
"sd_ratio": ["sd_ratio"],
@@ -79,3 +83,38 @@
"silhouette": ["silhouette"],
"silhouette_full": ["silhouette_full"],
}

# this dict pins the appropriate dtype for each metric column rather than letting pandas infer it
column_name_to_column_dtype = {
    "num_spikes": int,
    "firing_rate": float,
    "presence_ratio": float,
    "snr": float,
    "isi_violations_ratio": float,
    "isi_violations_count": float,
    "rp_violations": float,
    "rp_contamination": float,
    "sliding_rp_violation": float,
    "amplitude_cutoff": float,
    "amplitude_median": float,
    "amplitude_cv_median": float,
    "amplitude_cv_range": float,
    "sync_spike_2": float,
    "sync_spike_4": float,
    "sync_spike_8": float,
    "firing_range": float,
    "drift_ptp": float,
    "drift_std": float,
    "drift_mad": float,
    "sd_ratio": float,
    "isolation_distance": float,
    "l_ratio": float,
    "d_prime": float,
    "nn_hit_rate": float,
    "nn_miss_rate": float,
    "nn_isolation": float,
    "nn_unit_id": float,
    "nn_noise_overlap": float,
    "silhouette": float,
    "silhouette_full": float,
}
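A short usage sketch of this mapping, mirroring the loop added in quality_metric_calculator.py (the DataFrame here is hypothetical, and the import path assumes the package layout matches the file path shown above): only columns registered in the dict are cast, so any unregistered column keeps whatever dtype pandas inferred.

import pandas as pd

from spikeinterface.qualitymetrics.quality_metric_list import column_name_to_column_dtype

# hypothetical metrics table; convert_dtypes infers Int64 for both columns here
metrics = pd.DataFrame({"num_spikes": [10, 0], "snr": [3, 5]}).convert_dtypes()

for column in metrics.columns:
    if column in column_name_to_column_dtype:
        metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])

print(metrics.dtypes)  # num_spikes: int64, snr: float64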
@@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
    assert "isolation_distance" in metrics.columns


def test_merging_quality_metrics(sorting_analyzer_simple):

    sorting_analyzer = sorting_analyzer_simple

    metrics = compute_quality_metrics(
        sorting_analyzer,
        metric_names=None,
        qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
        skip_pc_metrics=False,
        seed=2205,
    )

    # sorting_analyzer_simple has ten units
    new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]])

    new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data()

    # the metrics should be copied over after the merge
    for column in metrics.columns:
        assert column in new_metrics.columns
        # the dtype should be copied too
        assert metrics[column].dtype == new_metrics[column].dtype

    # 10 units before the merge vs 9 units after
    assert len(metrics.index) > len(new_metrics.index)

def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):

    sorting_analyzer = sorting_analyzer_simple

@@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple):
        seed=2205,
    )

    # num_spikes is an int, not NaN, so we confirm that empty units are NaN for
    # every column except num_spikes, which should be 0
    nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"]
    for empty_unit_ids in sorting_empty.get_empty_unit_ids():
        from pandas import isnull

        assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values))
        if "num_spikes" in metrics_empty.columns:
            assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0


# TODO @alessio all these old tests should be moved into test_metric_functions.py or test_pca_metrics()