Merging with main
yger committed Jan 15, 2025
2 parents 835ef46 + 418bb86 commit c213697
Showing 14 changed files with 379 additions and 94 deletions.
4 changes: 3 additions & 1 deletion doc/conf.py
@@ -119,7 +119,9 @@

# for sphinx gallery plugin
sphinx_gallery_conf = {
'only_warn_on_example_error': True,
# This is the default but including here explicitly. Should build all docs and fail on gallery failures only.
# other option would be abort_on_example_error, but this fails on first failure. So we decided against this.
'only_warn_on_example_error': False,
'examples_dirs': ['../examples/tutorials'],
'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples
'subsection_order': ExplicitOrder([
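
For reference, a minimal sketch of how the two sphinx-gallery failure modes differ in a conf.py. Only the two error-handling keys and the gallery paths come from this diff; the behaviour described in the comments follows the explanation given in the diff itself.

sphinx_gallery_conf = {
    # False (the value chosen above): build every example, then fail the whole build
    # at the end if any gallery example raised an error.
    'only_warn_on_example_error': False,
    # The rejected alternative mentioned above: stop at the first failing example.
    # 'abort_on_example_error': True,
    'examples_dirs': ['../examples/tutorials'],
    'gallery_dirs': ['tutorials'],
}
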
4 changes: 4 additions & 0 deletions readthedocs.yml
@@ -1,5 +1,9 @@
version: 2

sphinx:
# Path to your Sphinx configuration file.
configuration: doc/conf.py

build:
os: ubuntu-22.04
tools:
7 changes: 7 additions & 0 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2092,6 +2092,13 @@ def load_data(self):
import pandas as pd

ext_data = pd.read_csv(ext_data_file, index_col=0)
# we need to cast the index to the unit id dtype (int or str)
unit_ids = self.sorting_analyzer.unit_ids
if ext_data.shape[0] == unit_ids.size:
# we force dtype to be the same as unit_ids
if ext_data.index.dtype != unit_ids.dtype:
ext_data.index = ext_data.index.astype(unit_ids.dtype)

elif ext_data_file.suffix == ".pkl":
with ext_data_file.open("rb") as f:
ext_data = pickle.load(f)
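
The cast above exists because pandas re-infers the index dtype on read_csv, which can disagree with the analyzer's unit_ids dtype (int vs. str). A self-contained sketch of the same idea, using a hypothetical file and ids:

import numpy as np
import pandas as pd

unit_ids = np.array([3, 7, 12])  # hypothetical analyzer unit ids (could also be strings)
pd.DataFrame({"snr": [5.1, 8.3, 2.4]}, index=unit_ids).to_csv("ext_data.csv")

ext_data = pd.read_csv("ext_data.csv", index_col=0)
# read_csv infers the index dtype, which may not match unit_ids.dtype,
# so force it back before using the index for unit lookups.
if ext_data.shape[0] == unit_ids.size and ext_data.index.dtype != unit_ids.dtype:
    ext_data.index = ext_data.index.astype(unit_ids.dtype)
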
1 change: 1 addition & 0 deletions src/spikeinterface/core/template.py
@@ -205,6 +205,7 @@ def to_sparse(self, sparsity):
unit_ids=self.unit_ids,
probe=self.probe,
check_for_consistent_sparsity=self.check_for_consistent_sparsity,
is_scaled=self.is_scaled,
)

def get_one_template_dense(self, unit_index):
40 changes: 24 additions & 16 deletions src/spikeinterface/curation/curation_format.py
@@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict):
if not removed_units_set.issubset(unit_set):
raise ValueError("Curation format: some removed units are not in the unit list")

for group in curation_dict["merge_unit_groups"]:
if len(group) < 2:
raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements")

all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
for gp_1, gp_2 in combinations(all_merging_groups, 2):
if len(gp_1.intersection(gp_2)) != 0:
raise ValueError("Some units belong to multiple merge groups")
raise ValueError("Curation format: some units belong to multiple merge groups")
if len(removed_units_set.intersection(merged_units_set)) != 0:
raise ValueError("Some units were merged and deleted")
raise ValueError("Curation format: some units were merged and deleted")

# Check the labels exclusivity
for lbl in curation_dict["manual_labels"]:
@@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
for unit_ind, unit_id in enumerate(sorting.unit_ids):
if unit_id not in new_unit_ids:
ind = curation_dict["unit_ids"].index(unit_id)
ind = list(curation_dict["unit_ids"]).index(unit_id)
all_values[unit_ind] = values[ind]
sorting.set_property(key, all_values)

@@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
group_values.append(value)
if len(set(group_values)) == 1:
# all group has the same label or empty
sorting.set_property(key, values=group_values, ids=[new_unit_id])
sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
else:

for key in label_def["label_options"]:
@@ -339,18 +343,22 @@ def apply_curation(

elif isinstance(sorting_or_analyzer, SortingAnalyzer):
analyzer = sorting_or_analyzer
analyzer = analyzer.remove_units(curation_dict["removed_units"])
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
if len(curation_dict["removed_units"]) > 0:
analyzer = analyzer.remove_units(curation_dict["removed_units"])
if len(curation_dict["merge_unit_groups"]) > 0:
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
else:
new_unit_ids = []
apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict)
return analyzer
else:
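
To make the validation rules above concrete, here is a hedged sketch of a curation dictionary that passes them. Only keys referenced in this diff are shown (a real dict typically carries further keys such as a format version and label definitions), and the checks are re-implemented locally rather than by calling validate_curation_dict:

from itertools import combinations

curation_dict = {
    "unit_ids": [0, 1, 2, 3, 4, 5],
    "merge_unit_groups": [[0, 1], [2, 3]],  # every group needs at least 2 units
    "removed_units": [5],                   # must not intersect the merged units
    "manual_labels": [],
}

unit_set = set(curation_dict["unit_ids"])
merged_units_set = {u for group in curation_dict["merge_unit_groups"] for u in group}
removed_units_set = set(curation_dict["removed_units"])

assert removed_units_set.issubset(unit_set)
assert all(len(group) >= 2 for group in curation_dict["merge_unit_groups"])
all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
assert all(not g1 & g2 for g1, g2 in combinations(all_merging_groups, 2))  # no unit in two groups
assert not removed_units_set & merged_units_set                            # nothing merged and removed
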
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/external/kilosort4.py
@@ -66,7 +66,7 @@ class Kilosort4Sorter(BaseSorter):
"do_correction": True,
"keep_good_only": False,
"skip_kilosort_preprocessing": False,
"use_binary_file": None,
"use_binary_file": True,
"delete_recording_dat": True,
}

@@ -116,7 +116,7 @@ class Kilosort4Sorter(BaseSorter):
"keep_good_only": "If True, only the units labeled as 'good' by Kilosort are returned in the output. (spikeinterface parameter)",
"use_binary_file": "If True then Kilosort is run using a binary file. In this case, if the input recording is not binary compatible, it is written to a binary file in the output folder. "
"If False then Kilosort is run on the recording object directly using the RecordingExtractorAsArray object. If None, then if the recording is binary compatible, the sorter will use the binary file, otherwise the RecordingExtractorAsArray. "
"Default is None. (spikeinterface parameter)",
"Default is True. (spikeinterface parameter)",
"delete_recording_dat": "If True, if a temporary binary file is created, it is deleted after the sorting is done. Default is True. (spikeinterface parameter)",
}

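
A hedged sketch of what the new default means when launching Kilosort4 through SpikeInterface; the run_sorter call and the toy recording are illustrative, while the two parameter names come from the diff:

from spikeinterface.core import generate_recording
from spikeinterface.sorters import run_sorter

recording = generate_recording(num_channels=16, durations=[10.0])  # placeholder recording

# With use_binary_file=True (the new default), a recording that is not binary
# compatible is first written to a binary file inside the output folder.
sorting = run_sorter(
    "kilosort4",
    recording,
    folder="ks4_output",
    use_binary_file=True,       # False would feed the recording object directly to Kilosort
    delete_recording_dat=True,  # clean up the temporary binary file afterwards
)
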
102 changes: 71 additions & 31 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -6,12 +6,16 @@
import numpy as np

from spikeinterface.core import NumpySorting
from spikeinterface.core.job_tools import fix_job_kwargs
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.template import Templates
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion
from spikeinterface.sortingcomponents.tools import cache_preprocessing
from spikeinterface.sortingcomponents.tools import (
cache_preprocessing,
get_prototype_and_waveforms_from_recording,
get_shuffled_recording_slices,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.sparsity import compute_sparsity

@@ -22,7 +26,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 75},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 10},
"whitening": {"mode": "local", "regularize": False},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
Expand All @@ -42,6 +46,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.5},
"seed": 42,
"debug": False,
}

@@ -63,18 +68,21 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"merging": "A dictionary to specify the final merging param to group cells after template matching (auto_merge_units)",
"motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
"apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
median reference + zscore",
median reference + whitening",
"apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not",
"matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
"seed": "An int to control how chunks are shuffled while detecting peaks",
"debug": "Boolean to specify if internal data structures made during the sorting should be kept for debugging",
}

sorter_description = """Spyking Circus 2 is a rewriting of Spyking Circus, within the SpikeInterface framework
It uses a more conservative clustering algorithm (compared to Spyking Circus), which is less prone to hallucinate units and/or find noise.
In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes
being discovered."""
being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface"""

@classmethod
def get_sorter_version(cls):
@@ -103,7 +111,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
from spikeinterface.sortingcomponents.tools import remove_empty_templates
from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction
from spikeinterface.sortingcomponents.tools import check_probe_for_drift_correction

job_kwargs = fix_job_kwargs(params["job_kwargs"])
job_kwargs.update({"progress_bar": verbose})
@@ -120,10 +128,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## First, we are filtering the data
filtering_params = params["filtering"].copy()
if params["apply_preprocessing"]:
if verbose:
print("Preprocessing the recording (bandpass filtering + CMR + whitening)")
recording_f = bandpass_filter(recording, **filtering_params, dtype="float32")
if num_channels > 1:
recording_f = common_reference(recording_f)
else:
if verbose:
print("Skipping preprocessing (whitening only)")
recording_f = recording
recording_f.annotate(is_filtered=True)

@@ -146,12 +158,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# TODO add , regularize=True chen ready
whitening_kwargs = params["whitening"].copy()
whitening_kwargs["dtype"] = "float32"
whitening_kwargs["radius_um"] = radius_um
whitening_kwargs["regularize"] = whitening_kwargs.get("regularize", False)
if num_channels == 1:
whitening_kwargs["regularize"] = False
if whitening_kwargs["regularize"]:
whitening_kwargs["regularize_kwargs"] = {"method": "LedoitWolf"}

recording_w = whiten(recording_f, **whitening_kwargs)
noise_levels = get_noise_levels(recording_w, return_scaled=False)
noise_levels = get_noise_levels(recording_w, return_scaled=False, **job_kwargs)

if recording_w.check_serializability("json"):
recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
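
A standalone, hedged sketch of the preprocessing chain used above (bandpass filter, common reference, local whitening, then noise estimation on the whitened traces); the toy recording is illustrative and the parameter values follow the defaults in this file:

from spikeinterface.core import generate_recording
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.preprocessing import bandpass_filter, common_reference, whiten

recording = generate_recording(num_channels=16, durations=[10.0])  # placeholder recording
rec_f = bandpass_filter(recording, freq_min=150, freq_max=7000, dtype="float32")
rec_f = common_reference(rec_f)  # skipped in the sorter when there is a single channel
rec_w = whiten(rec_f, mode="local", radius_um=75, dtype="float32")
# Noise levels are measured on the whitened recording, in unscaled units.
noise_levels = get_noise_levels(rec_w, return_scaled=False)
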
@@ -162,42 +176,69 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
detection_params.update(job_kwargs)

detection_params["radius_um"] = detection_params.get("radius_um", 50)
selection_params = params["selection"].copy()
detection_params["radius_um"] = radius_um
detection_params["exclude_sweep_ms"] = exclude_sweep_ms
detection_params["noise_levels"] = noise_levels

fs = recording_w.get_sampling_frequency()
nbefore = int(ms_before * fs / 1000.0)
nafter = int(ms_after * fs / 1000.0)

skip_peaks = not params["multi_units_only"] and selection_params.get("method", "uniform") == "uniform"
max_n_peaks = selection_params["n_peaks_per_channel"] * num_channels
n_peaks = max(selection_params["min_n_peaks"], max_n_peaks)

if params["debug"]:
clustering_folder = sorter_output_folder / "clustering"
clustering_folder.mkdir(parents=True, exist_ok=True)
np.save(clustering_folder / "noise_levels.npy", noise_levels)

if params["matched_filtering"]:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, skip_after_n_peaks=5000)
prototype = get_prototype_spike(recording_w, peaks, ms_before, ms_after, **job_kwargs)
prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
recording_w,
n_peaks=10000,
ms_before=ms_before,
ms_after=ms_after,
seed=params["seed"],
**detection_params,
**job_kwargs,
)
detection_params["prototype"] = prototype
detection_params["ms_before"] = ms_before
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params)
if params["debug"]:
np.save(clustering_folder / "waveforms.npy", waveforms)
np.save(clustering_folder / "prototype.npy", prototype)
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs)
else:
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params)
waveforms = None
if skip_peaks:
detection_params["skip_after_n_peaks"] = n_peaks
detection_params["recording_slices"] = get_shuffled_recording_slices(
recording_w, seed=params["seed"], **job_kwargs
)
peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs)

if verbose:
print("We found %d peaks in total" % len(peaks))
if not skip_peaks and verbose:
print("Found %d peaks in total" % len(peaks))

if params["multi_units_only"]:
sorting = NumpySorting.from_peaks(peaks, sampling_frequency, unit_ids=recording_w.unit_ids)
else:
## We subselect a subset of all the peaks, by making the distributions of SNRs over all
## channels as flat as possible
selection_params = params["selection"]
selection_params["n_peaks"] = min(len(peaks), selection_params["n_peaks_per_channel"] * num_channels)
selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"])

selection_params["n_peaks"] = n_peaks
selection_params.update({"noise_levels": noise_levels})
selected_peaks = select_peaks(peaks, **selection_params)

if verbose:
print("We kept %d peaks for clustering" % len(selected_peaks))
print("Kept %d peaks for clustering" % len(selected_peaks))

## We launch a clustering (using hdbscan) relying on positions and features extracted on
## the fly from the snippets
@@ -207,10 +248,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["radius_um"] = radius_um
clustering_params["waveforms"]["ms_before"] = ms_before
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["few_waveforms"] = waveforms
clustering_params["noise_levels"] = noise_levels
clustering_params["ms_before"] = exclude_sweep_ms
clustering_params["ms_after"] = exclude_sweep_ms
clustering_params["ms_before"] = ms_before
clustering_params["ms_after"] = ms_after
clustering_params["verbose"] = verbose
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params["noise_threshold"] = detection_params.get("detect_threshold", 4)

legacy = clustering_params.get("legacy", True)

@@ -235,12 +279,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
unit_ids = np.arange(len(np.unique(labeled_peaks["unit_index"])))
sorting = NumpySorting(labeled_peaks, sampling_frequency, unit_ids=unit_ids)

clustering_folder = sorter_output_folder / "clustering"
clustering_folder.mkdir(parents=True, exist_ok=True)

if not params["debug"]:
shutil.rmtree(clustering_folder)
else:
if params["debug"]:
np.save(clustering_folder / "peak_labels", peak_labels)
np.save(clustering_folder / "labels", labels)
np.save(clustering_folder / "peaks", selected_peaks)

@@ -283,7 +323,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(fitting_folder / "spikes", spikes)

if verbose:
print("We found %d spikes" % len(spikes))
print("Found %d spikes" % len(spikes))

## And this is it! We have a spyking circus
sorting = np.zeros(spikes.size, dtype=minimum_spike_dtype)
@@ -320,7 +360,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)

if verbose:
print(f"Final merging, keeping {len(sorting.unit_ids)} units")
print(f"Kept {len(sorting.unit_ids)} units after final merging")

folder_to_delete = None
cache_mode = params["cache_preprocessing"].get("mode", "memory")
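
To summarize the matched-filtering branch introduced above: a spike prototype is first estimated from a subsample of peaks, and peaks are then re-detected with that prototype as a matched filter, scanning shuffled chunks and stopping early once enough peaks are collected. A hedged sketch mirroring the calls in this diff (signatures follow their usage here; rec_w and noise_levels are the whitened recording and noise estimates from the previous sketch, and the numeric values are illustrative):

from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.tools import (
    get_prototype_and_waveforms_from_recording,
    get_shuffled_recording_slices,
)

job_kwargs = dict(n_jobs=4, progress_bar=True)
detection_params = dict(peak_sign="neg", detect_threshold=4, radius_um=75, noise_levels=noise_levels)

# 1) Estimate a prototype waveform (plus a few raw waveforms) from ~10k detected peaks.
prototype, waveforms, _ = get_prototype_and_waveforms_from_recording(
    rec_w, n_peaks=10000, ms_before=2, ms_after=2, seed=42, **detection_params, **job_kwargs
)

# 2) Re-detect with the prototype as a matched filter, over shuffled chunks,
#    stopping once enough peaks have been accumulated.
detection_params["prototype"] = prototype
detection_params["ms_before"] = 2
detection_params["skip_after_n_peaks"] = 100_000
detection_params["recording_slices"] = get_shuffled_recording_slices(rec_w, seed=42, **job_kwargs)
peaks = detect_peaks(rec_w, "matched_filtering", **detection_params, **job_kwargs)
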