Skip to content

Commit

Permalink
WIP: fetch all spikes in waveform fetch
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 14, 2025
1 parent 9d11d80 commit 4eaebf8
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 38 deletions.
71 changes: 37 additions & 34 deletions notebooks/_TEMP_Burst.ipynb

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions src/spyglass/spikesorting/v1/burst_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def _get_peak_amps1(

def _truncate_to_shortest(self, msg="", *args):
"""Truncate all arrays to the shortest length"""
if msg and not all([len(a) == len(args[0]) for a in args]):
mismatch = not all([len(a) == len(args[0]) for a in args])
if not mismatch:
return args
if msg and mismatch:
logger.warning(f"Truncating arrays to shortest length: {msg}")
min_len = min([len(a) for a in args])
return [a[:min_len] for a in args]
Expand All @@ -183,7 +186,9 @@ def get_peak_amps(
if cached := self._peak_amp_cache.get(key_hash):
return cached

waves = MetricCuration().get_waveforms(key, overwrite=False)
waves = MetricCuration().get_waveforms(
key, overwrite=False, fetch_all=True
)

curation_key = self._curation_key(key)
sorting = CurationV1.get_sorting(curation_key, as_dataframe=True)
Expand Down Expand Up @@ -550,6 +555,7 @@ def plot_1peak_over_time(
)

# PROBLEM: example key showed sub_ind larger than voltages
# SOLVED: fetch waveforms with "max_spikes_per_unit" as None
def select_voltages(voltages, sub_ind):
if len(sub_ind) > len(voltages):
sub_ind = sub_ind[: len(voltages)]
Expand Down
21 changes: 19 additions & 2 deletions src/spyglass/spikesorting/v1/metric_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,21 @@ def make(self, key):
AnalysisNwbfile().log(key, table=self.full_table_name)
self.insert1(key)

def get_waveforms(self, key: dict, overwrite: bool = True):
"""Returns waveforms identified by metric curation."""
def get_waveforms(
self, key: dict, overwrite: bool = True, fetch_all: bool = False
):
"""Returns waveforms identified by metric curation.
Parameters
----------
key : dict
primary key to MetricCuration
overwrite : bool, optional
whether to overwrite existing waveforms, by default True
fetch_all : bool, optional
fetch all spikes for units, by default False. Overrides
max_spikes_per_unit in waveform_params
"""
key_hash = dj.hash.key_hash(key)
if cached := self._waves_cache.get(key_hash):
return cached
Expand All @@ -311,6 +324,10 @@ def get_waveforms(self, key: dict, overwrite: bool = True):
if not any(wf_dir_obj.iterdir()): # if the directory is empty
overwrite = True

if fetch_all:
waveform_params["max_spikes_per_unit"] = None
waveforms_dir += "_all" # TODO: would it be better to overwrite?

# Extract non-sparse waveforms by default
waveform_params.setdefault("sparse", False)
waveforms = si.extract_waveforms(
Expand Down

0 comments on commit 4eaebf8

Please sign in to comment.