Skip to content

Commit

Permalink
Components of SC2 (#2870)
Browse files Browse the repository at this point in the history
SC2 parameters
  • Loading branch information
yger authored May 22, 2024
1 parent 70ebe17 commit c52be0e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
38 changes: 21 additions & 17 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"select_per_channel": False,
"seed": 42,
},
"drift_correction": {"preset": "nonrigid_fast_and_accurate"},
"apply_motion_correction": True,
"motion_correction": {"preset": "nonrigid_fast_and_accurate"},
"merging": {
"minimum_spikes": 10,
"corr_diff_thresh": 0.5,
"template_metric": "cosine",
"censor_correlograms_ms": 0.4,
"num_channels": 5,
"num_channels": None,
},
"clustering": {"legacy": True},
"matching": {"method": "circus-omp-svd"},
"matching": {"method": "wobble"},
"apply_preprocessing": True,
"matched_filtering": False,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.8},
Expand Down Expand Up @@ -121,6 +122,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 100)
exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

## First, we are filtering the data
filtering_params = params["filtering"].copy()
Expand All @@ -133,15 +135,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
recording_f.annotate(is_filtered=True)

valid_geometry = check_probe_for_drift_correction(recording_f)
if params["drift_correction"] is not None:
if params["apply_motion_correction"]:
if not valid_geometry:
print("Geometry of the probe does not allow 1D drift correction")
if verbose:
print("Geometry of the probe does not allow 1D drift correction")
motion_folder = None
else:
print("Motion correction activated (probe geometry compatible)")
if verbose:
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["drift_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["drift_correction"])
params["motion_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["motion_correction"])
else:
motion_folder = None

Expand All @@ -163,7 +167,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
detection_params.update(job_kwargs)

detection_params["radius_um"] = detection_params.get("radius_um", 50)
detection_params["exclude_sweep_ms"] = detection_params.get("exclude_sweep_ms", 0.5)
detection_params["exclude_sweep_ms"] = exclude_sweep_ms
detection_params["noise_levels"] = noise_levels

fs = recording_w.get_sampling_frequency()
Expand Down Expand Up @@ -212,6 +216,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["job_kwargs"] = job_kwargs
clustering_params["noise_levels"] = noise_levels
clustering_params["ms_before"] = exclude_sweep_ms
clustering_params["ms_after"] = exclude_sweep_ms
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"

legacy = clustering_params.get("legacy", True)
Expand Down Expand Up @@ -275,12 +281,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_params["templates"] = templates
matching_job_params = job_kwargs.copy()

if matching_method == "circus-omp-svd":

for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
Expand Down Expand Up @@ -308,7 +312,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
merging_params = params["merging"].copy()

if len(merging_params) > 0:
if params["drift_correction"] and motion_folder is not None:
if params["motion_correction"] and motion_folder is not None:
from spikeinterface.preprocessing.motion import load_motion_info

motion_info = load_motion_info(motion_folder)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class CircusClustering:
},
"radius_um": 100,
"n_svd": [5, 2],
"ms_before": 2,
"ms_after": 2,
"ms_before": 0.5,
"ms_after": 0.5,
"noise_levels": None,
"tmp_folder": None,
"job_kwargs": {},
Expand Down

0 comments on commit c52be0e

Please sign in to comment.