Sc2 making use of new recording slices #3518

Merged: 83 commits into SpikeInterface:main from yger:sc2_recording_slices on Jan 15, 2025
Changes from 66 commits

Commits (83)
6b9350c
Skipping peaks if enough have been collected
yger Nov 4, 2024
da9bd9d
fix
yger Nov 4, 2024
7f89e87
Merge branch 'SpikeInterface:main' into sc2_recording_slices
yger Nov 5, 2024
7ff0d12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
611170d
Docs and cleaning n_jobs within sc2
yger Nov 13, 2024
9f86617
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
019b399
Merge branch 'clustering_components_api' into sc2_recording_slices
yger Nov 13, 2024
dfd9bff
Merge branch 'clustering_components_api' into sc2_recording_slices
yger Nov 13, 2024
bd525b2
Exclude 32 chan probe by default for motion correction
yger Nov 13, 2024
fd9ee0b
One pass to get the prototype
yger Nov 15, 2024
8e37d54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
1c8443f
Verbose flag
yger Nov 15, 2024
80cca25
WIP
yger Nov 15, 2024
801a362
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
48026e3
WIP
yger Nov 15, 2024
472022d
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Nov 15, 2024
d3bee9a
Fixes for the clustering
yger Nov 15, 2024
0865dcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
f4ff9ab
WIP
yger Nov 19, 2024
28229d8
WIP
yger Nov 19, 2024
70673ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
6002da4
WIP
yger Nov 21, 2024
2310811
WIP
yger Nov 21, 2024
0054974
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2024
9c61e71
Merge branch 'main' into sc2_recording_slices
yger Nov 21, 2024
2885aeb
WIP
yger Nov 22, 2024
94cae34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2024
3b39a05
job_kwargs propagated everywhere
yger Nov 22, 2024
c5fecf4
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Nov 22, 2024
28431f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2024
24b22da
Removing noise templates
yger Nov 22, 2024
0b7674c
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Nov 22, 2024
ce8432a
WIP
yger Nov 22, 2024
a396e1b
Skipping waveforms re-detection if matched filtering
yger Nov 27, 2024
cb560e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
7c5bce7
In case of both peak detection, waveforms should be aligned
yger Nov 27, 2024
047cac4
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Nov 27, 2024
94c63f3
WIP
yger Nov 27, 2024
d39874f
Merge branch 'SpikeInterface:main' into sc2_recording_slices
yger Nov 28, 2024
0d1907f
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Nov 28, 2024
57a9cc9
Whitening
yger Nov 29, 2024
02cad49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2024
29fda81
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Dec 4, 2024
b9af739
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Dec 5, 2024
06711d1
Merge branch 'SpikeInterface:main' into sc2_recording_slices
yger Dec 11, 2024
d18c2ef
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Dec 19, 2024
a6e3c26
WIP
yger Dec 19, 2024
1de6503
WIP
yger Jan 8, 2025
c51fd1a
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Jan 8, 2025
e4d7e72
Harmonization of get_prototype_and_waveforms
yger Jan 8, 2025
158347c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
437d8b4
Docstrings and cosmetics
yger Jan 8, 2025
2cad313
Docstrings
yger Jan 8, 2025
c2cacd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
1272008
Merge branch 'main' into sc2_recording_slices
yger Jan 8, 2025
0e8603d
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Jan 8, 2025
6119609
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Jan 8, 2025
fb9747d
Refactoring get_prototype
yger Jan 9, 2025
6b3aba1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
e839894
Docstrings
yger Jan 9, 2025
731c713
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Jan 9, 2025
7c32a65
Removing artefactual templates due to matched filtering
yger Jan 9, 2025
c1c8b06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
f994cd8
Fix
yger Jan 9, 2025
472b99c
Merge branch 'sc2_recording_slices' of github.com:yger/spikeinterface…
yger Jan 9, 2025
7cdb239
Cleaning and adapting code for clustering
yger Jan 9, 2025
efbb9a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
6d5d180
Merge branch 'SpikeInterface:main' into sc2_recording_slices
yger Jan 13, 2025
b901241
Fixing engine
yger Jan 13, 2025
70be7a4
Typos
yger Jan 13, 2025
46eba7b
Revert split related stuffs
yger Jan 13, 2025
13fa06c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
5c1983c
Fix
yger Jan 13, 2025
64e6e67
WIp
yger Jan 13, 2025
0ec2179
Margin_ms for filtering is too short
yger Jan 14, 2025
23dcacd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
827fb4d
Update src/spikeinterface/sortingcomponents/tools.py
yger Jan 14, 2025
56064df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
6371b82
Merge branch 'main' into sc2_recording_slices
yger Jan 14, 2025
beb8f23
Default params
yger Jan 15, 2025
eacaa0b
Verbose
yger Jan 15, 2025
57d8716
Faster mixture merging
yger Jan 15, 2025
0ca8311
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2025
1 change: 1 addition & 0 deletions src/spikeinterface/core/template.py
@@ -205,6 +205,7 @@ def to_sparse(self, sparsity):
unit_ids=self.unit_ids,
probe=self.probe,
check_for_consistent_sparsity=self.check_for_consistent_sparsity,
+ is_scaled=self.is_scaled,
)

def get_one_template_dense(self, unit_index):
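The single added line above forwards the is_scaled flag when a dense Templates object is converted to a sparse one; without it, the sparse copy silently reverted to the constructor default. A minimal sketch of the bug pattern and the fix, using hypothetical names rather than the real SpikeInterface class:

    from dataclasses import dataclass
    import numpy as np

    @dataclass
    class ToyTemplates:
        templates_array: np.ndarray
        is_scaled: bool = True  # True if amplitudes are in physical units (uV)

        def to_sparse_buggy(self, mask):
            # Bug: is_scaled is not forwarded, so the copy resets to the default (True)
            return ToyTemplates(templates_array=self.templates_array * mask)

        def to_sparse_fixed(self, mask):
            # Fix: propagate the flag, mirroring the one-line change above
            return ToyTemplates(templates_array=self.templates_array * mask, is_scaled=self.is_scaled)

    dense = ToyTemplates(np.ones((2, 10, 4), dtype="float32"), is_scaled=False)
    assert dense.to_sparse_buggy(np.ones(4, dtype="float32")).is_scaled is True   # flag lost
    assert dense.to_sparse_fixed(np.ones(4, dtype="float32")).is_scaled is False  # flag kept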
107 changes: 74 additions & 33 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,12 +6,16 @@
import numpy as np

from spikeinterface.core import NumpySorting
- from spikeinterface.core.job_tools import fix_job_kwargs
+ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.template import Templates
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
- from spikeinterface.sortingcomponents.tools import cache_preprocessing
+ from spikeinterface.sortingcomponents.tools import (
+     cache_preprocessing,
+     get_prototype_and_waveforms_from_recording,
+     get_shuffled_recording_slices,
+ )
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
@@ -24,7 +28,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 50},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"whitening": {"mode": "local", "regularize": False},
@@ -47,12 +51,13 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
},
},
"clustering": {"legacy": True},
"matching": {"method": "circus-omp-svd"},
"matching": {"method": "wobble"},
"apply_preprocessing": True,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.5},
"seed": 42,
"debug": False,
}
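For orientation, the two new entries (seed and debug) are passed like any other sorter parameter. A hedged usage sketch, assuming this branch is installed; generate_recording is only a stand-in for a real recording, and the exact run_sorter folder argument name varies across SpikeInterface versions:

    from spikeinterface.core import generate_recording
    from spikeinterface.sorters import run_sorter

    recording = generate_recording(num_channels=4, durations=[10.0])  # toy placeholder
    sorting = run_sorter(
        "spykingcircus2",
        recording,
        seed=42,                     # controls the shuffling of recording chunks during detection
        debug=True,                  # keep intermediate clustering artifacts on disk
        matched_filtering=True,
        job_kwargs={"n_jobs": 0.5},  # fraction of available cores, see the note after the next hunk
    )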

@@ -74,18 +79,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)",
"motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
"apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
median reference + zscore",
median reference + whitening",
"apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not",
"matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
"seed": "An int to control how chunks are shuffled while detecting peaks",
"debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
}

sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework
It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise.
In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes
being discovered."""
being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface"""

@classmethod
def get_sorter_version(cls):
@@ -114,7 +122,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.sortingcomponents.tools import remove_empty_templates
- from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
+ from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction

job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})
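A float n_jobs (the default above is 0.5) is interpreted by fix_job_kwargs as a fraction of the available cores. A rough sketch of that resolution rule, not the actual implementation:

    import os

    def resolve_n_jobs(n_jobs):
        # float in (0, 1] -> fraction of cores; otherwise an explicit core count
        if isinstance(n_jobs, float) and 0 < n_jobs <= 1:
            return max(1, int(n_jobs * (os.cpu_count() or 1)))
        return int(n_jobs)

    print(resolve_n_jobs(0.5))  # e.g. 4 on an 8-core machine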
@@ -125,16 +133,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
num_channels = recording.get_num_channels()
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 75)
radius_um = params["general"].get("radius_um", 50)
exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
+ if verbose:
+     print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
if num_channels > 1:
recording_f = common_reference(recording_f)
else:
+ if verbose:
+     print("Skipping preprocessing (whitening only)")
recording_f = recording
recording_f.annotate(is_filtered=True)

@@ -157,12 +169,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# TODO add regularize=True when ready
whitening_kwargs = params["whitening"].copy()
whitening_kwargs["dtype"] = "float32"
whitening_kwargs["radius_um"] = radius_um
whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
+ if num_channels == 1:
+     whitening_kwargs["regularize"] = False
if whitening_kwargs["regularize"]:
whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}

recording_w = whiten(recording_f, **whitening_kwargs)
- noise_levels = get_noise_levels(recording_w, return_scaled=False)
+ noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)

if recording_w.check_serializability("json"):
recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
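get_noise_levels now receives **job_kwargs, so the noise estimation is chunked and parallelized like the other steps. SpikeInterface's default estimator is the per-channel MAD; a self-contained numpy equivalent for a single chunk of traces (illustrative only, not the library code):

    import numpy as np

    def mad_noise_levels(traces):
        # traces: (num_samples, num_channels); robust per-channel noise estimate
        med = np.median(traces, axis=0, keepdims=True)
        # 1 / 0.6745 rescales the MAD to the standard deviation of a Gaussian
        return np.median(np.abs(traces - med), axis=0) / 0.6744897501960817

    rng = np.random.default_rng(0)
    traces = rng.normal(0.0, 5.0, size=(30_000, 4)).astype("float32")
    print(mad_noise_levels(traces))  # approximately [5, 5, 5, 5]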
@@ -173,24 +187,53 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
- detection_params.update(job_kwargs)

detection_params["radius_um"] = detection_params.get("radius_um", 50)
selection_params = params["selection"].copy()
detection_params["radius_um"] = radius_um
detection_params["exclude_sweep_ms"] = exclude_sweep_ms
detection_params["noise_levels"] = noise_levels

fs = recording_w.get_sampling_frequency()
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

+ skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform"
+ max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
+ n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)

if params["debug"]:
clustering_folder = sorter_output_folder / "clustering"
clustering_folder.mkdir(parents=True, exist_ok=True)
np.save(clustering_folder / "noise_levels.npy", noise_levels)

if params["matched_filtering"]:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
recording_w,
n_peaks=10000,
ms_before=ms_before,
ms_after=ms_after,
seed=params["seed"],
**detection_params,
**job_kwargs,
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
if params["debug"]:
np.save(clustering_folder / "waveforms.npy", waveforms)
np.save(clustering_folder / "prototype.npy", prototype)
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs)
else:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)
waveforms = None
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)

if verbose:
print("We found %d peaks in total" % len(peaks))
@@ -201,9 +244,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## We subselect a subset of all the peaks, making the distribution of SNRs over all
## channels as flat as possible
selection_params = params["selection"]
selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

selection_params["n_peaks"] = n_peaks
selection_params.update({"noise_levels": noise_levels})
selected_peaks = select_peaks(peaks, **selection_params)

@@ -218,10 +259,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["radius_um"] = radius_um
clustering_params["waveforms"]["ms_before"] = ms_before
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["few_waveforms"] = waveforms
clustering_params["noise_levels"] = noise_levels
clustering_params["ms_before"] = exclude_sweep_ms
clustering_params["ms_after"] = exclude_sweep_ms
clustering_params["ms_before"] = ms_before
clustering_params["ms_after"] = ms_after
clustering_params["verbose"] = verbose
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4)

legacy = clustering_params.get("legacy", True)

@@ -246,12 +290,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"])))
sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids)

clustering_folder = sorter_output_folder / "clustering"
clustering_folder.mkdir(parents=True, exist_ok=True)

if not params["debug"]:
shutil.rmtree(clustering_folder)
else:
if params["debug"]:
np.save(clustering_folder / "peak_labels", peak_labels)
np.save(clustering_folder / "labels", labels)
np.save(clustering_folder / "peaks", selected_peaks)

@@ -334,7 +374,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting.save(folder=curation_folder)
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)

- sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params)
+ sorting = final_cleaning_circus(recording_w, sorting, templates, merging_params, **job_kwargs)

if verbose:
print(f"Final merging, keeping {len(sorting.unit_ids)} units")
@@ -376,17 +416,18 @@ def create_sorting_analyzer_with_templates(sorting, recording, templates, remove
return sa


- def final_cleaning_circus(recording, sorting, templates, **merging_kwargs):
+ def final_cleaning_circus(recording, sorting, templates, merging_kwargs, **job_kwargs):

from spikeinterface.core.sorting_tools import apply_merges_to_sorting

sa = create_sorting_analyzer_with_templates(sorting, recording, templates)

sa.compute("unit_locations", method="monopolar_triangulation")
sa.compute("unit_locations", method="monopolar_triangulation", **job_kwargs)
similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {})
sa.compute("template_similarity", **similarity_kwargs)
sa.compute("template_similarity", **similarity_kwargs, **job_kwargs)
correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {})
sa.compute("correlograms", **correlograms_kwargs)
sa.compute("correlograms", **correlograms_kwargs, **job_kwargs)

auto_merge_kwargs = merging_kwargs.pop("auto_merge", {})
merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs)
sorting = apply_merges_to_sorting(sa.sorting, merges)
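The signature change in final_cleaning_circus (an explicit merging_kwargs dict plus **job_kwargs, instead of one **merging_kwargs) matches the PR-wide theme of propagating job_kwargs: when configuration and execution options share a single **kwargs, n_jobs cannot be routed to the compute calls without masquerading as a merging option. A minimal sketch of the pattern, with hypothetical names:

    def compute_extension(name, n_jobs=1, **opts):
        # stand-in for sorting_analyzer.compute(...)
        print(f"computing {name}: n_jobs={n_jobs}, opts={opts}")

    def final_cleaning_split(merging_kwargs, **job_kwargs):
        # configuration is popped from one explicit dict ...
        similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {})
        correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {})
        # ... while execution options travel separately and reach every compute call
        compute_extension("template_similarity", **similarity_kwargs, **job_kwargs)
        compute_extension("correlograms", **correlograms_kwargs, **job_kwargs)

    final_cleaning_split({"similarity_kwargs": {"method": "cosine"}}, n_jobs=4)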