From 91c1105d162bd288e054d56ef01a4cde9f3871df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 14:03:15 +0100 Subject: [PATCH 01/13] Allows `export_to_phy` with `fast_templates` --- .../core/analyzer_extension_core.py | 34 +++++++++++++++++-- src/spikeinterface/exporters/to_phy.py | 4 ++- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 268513dac8..c288be53bc 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -397,10 +397,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save ---------- unit_ids: list or None Unit ids to retrieve waveforms for - mode: "average" | "median" | "std" | "percentile", default: "average" - The mode to compute the templates + operator: "average" | "median" | "std" | "percentile", default: "average" + The operator to compute the templates percentile: float, default: None - Percentile to use for mode="percentile" + Percentile to use for operator="percentile" save: bool, default True In case, the operator is not computed yet it can be saved to folder or zarr. @@ -520,6 +520,34 @@ def _select_extension_data(self, unit_ids): return new_data + def get_templates(self, unit_ids=None, save=True): + """ + Return average templates for multiple units. + + Parameters + ---------- + unit_ids: list or None + Unit ids to retrieve waveforms for + save: bool, default True + In case, the operator is not computed yet it can be saved to folder or zarr. + + Returns + ------- + templates: np.array + The returned templates (num_units, num_samples, num_channels) + """ + + templates = self.data['average'] + + if save: + self.save() + + if unit_ids is not None: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + templates = templates[unit_indices, :, :] + + return np.array(templates) + compute_fast_templates = ComputeFastTemplates.function_factory() register_result_extension(ComputeFastTemplates) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 30f74e584b..c307319744 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -181,7 +181,9 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) templates_ext = sorting_analyzer.get_extension("templates") - templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'" + if templates_ext is None: + templates_ext = sorting_analyzer.get_extension("fast_templates") + assert templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] From df6757841f1c446e43e199ec80db70c08dfafda3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 14:11:25 +0100 Subject: [PATCH 02/13] Fixed bug --- src/spikeinterface/exporters/to_phy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index c307319744..fed00d7442 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -183,9 +183,13 @@ def export_to_phy( templates_ext = sorting_analyzer.get_extension("templates") if templates_ext is None: templates_ext = sorting_analyzer.get_extension("fast_templates") + if templates_ext is not None and template_mode != "average": + assert False, "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" + dense_templates = templates_ext.get_templates(unit_ids=unit_ids) + else: + dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) assert templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) - dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") # here we pad template inds with -1 if len of sparse channels is unequal From 31f98d6c7b0fd01797d68f6735c86dcf0e04972d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:12:24 +0000 Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- src/spikeinterface/exporters/to_phy.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index c288be53bc..30415afacb 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -537,7 +537,7 @@ def get_templates(self, unit_ids=None, save=True): The returned templates (num_units, num_samples, num_channels) """ - templates = self.data['average'] + templates = self.data["average"] if save: self.save() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index fed00d7442..229a0f2677 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -184,11 +184,15 @@ def export_to_phy( if templates_ext is None: templates_ext = sorting_analyzer.get_extension("fast_templates") if templates_ext is not None and template_mode != "average": - assert False, "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" + assert ( + False + ), "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" dense_templates = templates_ext.get_templates(unit_ids=unit_ids) else: dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) - assert templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" + assert ( + templates_ext is not None + ), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") From f573c61f4012683885c68a547fe031a1c8842b46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 14:26:54 +0100 Subject: [PATCH 04/13] Sam's suggestions + `get_unit_template` --- .../core/analyzer_extension_core.py | 27 ++++++++++++++----- src/spikeinterface/exporters/to_phy.py | 4 +-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 30415afacb..8ecd76c8af 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -520,7 +520,7 @@ def _select_extension_data(self, unit_ids): return new_data - def get_templates(self, unit_ids=None, save=True): + def get_templates(self, unit_ids=None): """ Return average templates for multiple units. @@ -528,8 +528,6 @@ def get_templates(self, unit_ids=None, save=True): ---------- unit_ids: list or None Unit ids to retrieve waveforms for - save: bool, default True - In case, the operator is not computed yet it can be saved to folder or zarr. Returns ------- @@ -539,15 +537,32 @@ def get_templates(self, unit_ids=None, save=True): templates = self.data["average"] - if save: - self.save() - if unit_ids is not None: unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) templates = templates[unit_indices, :, :] return np.array(templates) + def get_unit_template(self, unit_id): + """ + Return average template for a single unit. + + Parameters + ---------- + unit_id: + Unit id to retrieve waveforms for + + Returns + ------- + template: np.array + The returned template (num_samples, num_channels) + """ + + templates = self.data["average"] + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + + return np.array(templates[unit_index, :, :]) + compute_fast_templates = ComputeFastTemplates.function_factory() register_result_extension(ComputeFastTemplates) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 229a0f2677..ad79223779 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -184,9 +184,7 @@ def export_to_phy( if templates_ext is None: templates_ext = sorting_analyzer.get_extension("fast_templates") if templates_ext is not None and template_mode != "average": - assert ( - False - ), "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" + raise ValueError("export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'") dense_templates = templates_ext.get_templates(unit_ids=unit_ids) else: dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) From faf173585fe6d4a72cdc1b3338a02eeb2f5d1dfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Mar 2024 13:28:31 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/to_phy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ad79223779..8dc68b801b 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -184,7 +184,9 @@ def export_to_phy( if templates_ext is None: templates_ext = sorting_analyzer.get_extension("fast_templates") if templates_ext is not None and template_mode != "average": - raise ValueError("export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'") + raise ValueError( + "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" + ) dense_templates = templates_ext.get_templates(unit_ids=unit_ids) else: dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) From 1217184aedb79f3cb205469aa12c76337cb86df2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 14:34:24 +0100 Subject: [PATCH 06/13] Update src/spikeinterface/exporters/to_phy.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/exporters/to_phy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 8dc68b801b..cf8d438622 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -192,7 +192,7 @@ def export_to_phy( dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) assert ( templates_ext is not None - ), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" + ), "export_to_phy requires SortingAnalyzer with either extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") From 3f82ba11deffb590009711c2981fb487695bf711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 14:34:34 +0100 Subject: [PATCH 07/13] Update src/spikeinterface/core/analyzer_extension_core.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 8ecd76c8af..5d48434de0 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -526,7 +526,7 @@ def get_templates(self, unit_ids=None): Parameters ---------- - unit_ids: list or None + unit_ids: list or None, default: None Unit ids to retrieve waveforms for Returns From 2611adc39d79d50109dc52b1997626fdab28ce66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 6 Mar 2024 15:49:42 +0100 Subject: [PATCH 08/13] Doc improvement --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 5d48434de0..0ee012832c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -549,7 +549,7 @@ def get_unit_template(self, unit_id): Parameters ---------- - unit_id: + unit_id: str | int Unit id to retrieve waveforms for Returns From 6b7769788f708743892222167beddb9b549b5790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 27 Mar 2024 10:02:52 +0100 Subject: [PATCH 09/13] Cleaner implementation --- .../core/analyzer_extension_core.py | 6 +++++- src/spikeinterface/exporters/to_phy.py | 15 ++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 42e75336b3..3faea0f3ee 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -522,7 +522,7 @@ def _select_extension_data(self, unit_ids): return new_data - def get_templates(self, unit_ids=None): + def get_templates(self, unit_ids=None, operator="average"): """ Return average templates for multiple units. @@ -530,6 +530,9 @@ def get_templates(self, unit_ids=None): ---------- unit_ids: list or None, default: None Unit ids to retrieve waveforms for + operator: str + MUST be "average" (only one supported by fast_templates) + The argument exist to have the same signature as ComputeTemplates.get_templates Returns ------- @@ -537,6 +540,7 @@ def get_templates(self, unit_ids=None): The returned templates (num_units, num_samples, num_channels) """ + assert operator == f"average", "Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" templates = self.data["average"] if unit_ids is not None: diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index cf8d438622..dc824ee775 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -180,19 +180,8 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) - templates_ext = sorting_analyzer.get_extension("templates") - if templates_ext is None: - templates_ext = sorting_analyzer.get_extension("fast_templates") - if templates_ext is not None and template_mode != "average": - raise ValueError( - "export_to_phy with SortingAnalyzer with extension 'fast_templates' can only work with template_mode='average'" - ) - dense_templates = templates_ext.get_templates(unit_ids=unit_ids) - else: - dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) - assert ( - templates_ext is not None - ), "export_to_phy requires SortingAnalyzer with either extension 'templates' or 'fast_templates'" + templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates") + dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") From 4f7f15f391a275c08e01e8a7b32144a1572ca6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 27 Mar 2024 10:04:06 +0100 Subject: [PATCH 10/13] oops --- src/spikeinterface/exporters/to_phy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index dc824ee775..8e6d600c82 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -181,8 +181,9 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates") - dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) + assert templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) + dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64") # here we pad template inds with -1 if len of sparse channels is unequal From 44de825363e35df3bcbbdebbaa33e408d0c954b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 27 Mar 2024 10:04:51 +0100 Subject: [PATCH 11/13] oops --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 3faea0f3ee..c2c8f9bc6c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -540,7 +540,7 @@ def get_templates(self, unit_ids=None, operator="average"): The returned templates (num_units, num_samples, num_channels) """ - assert operator == f"average", "Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" + assert operator == "average", f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" templates = self.data["average"] if unit_ids is not None: From a63d7a0dd03fc154e23df2ab73c8cb73d967efcc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:06:35 +0000 Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/analyzer_extension_core.py | 4 +++- src/spikeinterface/exporters/to_phy.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index c2c8f9bc6c..5b343cfdae 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -540,7 +540,9 @@ def get_templates(self, unit_ids=None, operator="average"): The returned templates (num_units, num_samples, num_channels) """ - assert operator == "average", f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" + assert ( + operator == "average" + ), f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}" templates = self.data["average"] if unit_ids is not None: diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 8e6d600c82..52cf052ed2 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -181,7 +181,9 @@ def export_to_phy( # export templates/templates_ind/similar_templates # shape (num_units, num_samples, max_num_channels) templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates") - assert templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" + assert ( + templates_ext is not None + ), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'" max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values()) dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode) num_samples = dense_templates.shape[1] From 645837eb640240ceb6905cd52e39edeb7e2b69f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 27 Mar 2024 10:47:49 +0100 Subject: [PATCH 13/13] Added `get_unit_template` to classic templates extension --- .../core/analyzer_extension_core.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 5b343cfdae..87c6b8acbc 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -437,6 +437,28 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save return np.array(templates) + def get_unit_template(self, unit_id, operator="average"): + """ + Return template for a single unit. + + Parameters + ---------- + unit_id: str | int + Unit id to retrieve waveforms for + operator: str + The operator to compute the templates + + Returns + ------- + template: np.array + The returned template (num_samples, num_channels) + """ + + templates = self.data[operator] + unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + + return np.array(templates[unit_index, :, :]) + compute_templates = ComputeTemplates.function_factory() register_result_extension(ComputeTemplates)