Skip to content

Commit

Permalink
Merge pull request #2549 from DradeAW/fix_export_phy_templates
Browse files Browse the repository at this point in the history
Allow `export_to_phy` to work with `fast_templates`
  • Loading branch information
alejoe91 authored Mar 27, 2024
2 parents 56a0eea + 645837e commit 93f9cc5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 5 deletions.
77 changes: 74 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,10 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
----------
unit_ids: list or None
Unit ids to retrieve waveforms for
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the templates
operator: "average" | "median" | "std" | "percentile", default: "average"
The operator to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
Percentile to use for operator="percentile"
save: bool, default True
In case, the operator is not computed yet it can be saved to folder or zarr.
Expand Down Expand Up @@ -437,6 +437,28 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save

return np.array(templates)

def get_unit_template(self, unit_id, operator="average"):
"""
Return template for a single unit.
Parameters
----------
unit_id: str | int
Unit id to retrieve waveforms for
operator: str
The operator to compute the templates
Returns
-------
template: np.array
The returned template (num_samples, num_channels)
"""

templates = self.data[operator]
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)

return np.array(templates[unit_index, :, :])


compute_templates = ComputeTemplates.function_factory()
register_result_extension(ComputeTemplates)
Expand Down Expand Up @@ -522,6 +544,55 @@ def _select_extension_data(self, unit_ids):

return new_data

def get_templates(self, unit_ids=None, operator="average"):
"""
Return average templates for multiple units.
Parameters
----------
unit_ids: list or None, default: None
Unit ids to retrieve waveforms for
operator: str
MUST be "average" (only one supported by fast_templates)
The argument exist to have the same signature as ComputeTemplates.get_templates
Returns
-------
templates: np.array
The returned templates (num_units, num_samples, num_channels)
"""

assert (
operator == "average"
), f"Analyzer extension `fast_templates` only works with 'average' templates. Given operator = {operator}"
templates = self.data["average"]

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
templates = templates[unit_indices, :, :]

return np.array(templates)

def get_unit_template(self, unit_id):
"""
Return average template for a single unit.
Parameters
----------
unit_id: str | int
Unit id to retrieve waveforms for
Returns
-------
template: np.array
The returned template (num_samples, num_channels)
"""

templates = self.data["average"]
unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)

return np.array(templates[unit_index, :, :])


compute_fast_templates = ComputeFastTemplates.function_factory()
register_result_extension(ComputeFastTemplates)
Expand Down
6 changes: 4 additions & 2 deletions src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,10 @@ def export_to_phy(

# export templates/templates_ind/similar_templates
# shape (num_units, num_samples, max_num_channels)
templates_ext = sorting_analyzer.get_extension("templates")
templates_ext is not None, "export_to_phy need SortingAnalyzer with extension 'templates'"
templates_ext = sorting_analyzer.get_extension("templates") or sorting_analyzer.get_extension("fast_templates")
assert (
templates_ext is not None
), "export_to_phy need SortingAnalyzer with extension 'templates' or 'fast_templates'"
max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values())
dense_templates = templates_ext.get_templates(unit_ids=unit_ids, operator=template_mode)
num_samples = dense_templates.shape[1]
Expand Down

0 comments on commit 93f9cc5

Please sign in to comment.