Skip to content

Commit

Permalink
Merge branch 'main' into mpl_cmap
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin authored May 22, 2024
2 parents 4b67b20 + 29ad02b commit cd3dc7a
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 98 deletions.
38 changes: 21 additions & 17 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"select_per_channel": False,
"seed": 42,
},
"drift_correction": {"preset": "nonrigid_fast_and_accurate"},
"apply_motion_correction": True,
"motion_correction": {"preset": "nonrigid_fast_and_accurate"},
"merging": {
"minimum_spikes": 10,
"corr_diff_thresh": 0.5,
"template_metric": "cosine",
"censor_correlograms_ms": 0.4,
"num_channels": 5,
"num_channels": None,
},
"clustering": {"legacy": True},
"matching": {"method": "circus-omp-svd"},
"matching": {"method": "wobble"},
"apply_preprocessing": True,
"matched_filtering": False,
"matched_filtering": True,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"multi_units_only": False,
"job_kwargs": {"n_jobs": 0.8},
Expand Down Expand Up @@ -121,6 +122,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
ms_before = params["general"].get("ms_before", 2)
ms_after = params["general"].get("ms_after", 2)
radius_um = params["general"].get("radius_um", 100)
exclude_sweep_ms = params["detection"].get("exclude_sweep_ms", max(ms_before, ms_after) / 2)

## First, we are filtering the data
filtering_params = params["filtering"].copy()
Expand All @@ -133,15 +135,17 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
recording_f.annotate(is_filtered=True)

valid_geometry = check_probe_for_drift_correction(recording_f)
if params["drift_correction"] is not None:
if params["apply_motion_correction"]:
if not valid_geometry:
print("Geometry of the probe does not allow 1D drift correction")
if verbose:
print("Geometry of the probe does not allow 1D drift correction")
motion_folder = None
else:
print("Motion correction activated (probe geometry compatible)")
if verbose:
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params["drift_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["drift_correction"])
params["motion_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["motion_correction"])
else:
motion_folder = None

Expand All @@ -163,7 +167,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
detection_params.update(job_kwargs)

detection_params["radius_um"] = detection_params.get("radius_um", 50)
detection_params["exclude_sweep_ms"] = detection_params.get("exclude_sweep_ms", 0.5)
detection_params["exclude_sweep_ms"] = exclude_sweep_ms
detection_params["noise_levels"] = noise_levels

fs = recording_w.get_sampling_frequency()
Expand Down Expand Up @@ -212,6 +216,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["waveforms"]["ms_after"] = ms_after
clustering_params["job_kwargs"] = job_kwargs
clustering_params["noise_levels"] = noise_levels
clustering_params["ms_before"] = exclude_sweep_ms
clustering_params["ms_after"] = exclude_sweep_ms
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"

legacy = clustering_params.get("legacy", True)
Expand Down Expand Up @@ -275,12 +281,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
matching_params["templates"] = templates
matching_job_params = job_kwargs.copy()

if matching_method == "circus-omp-svd":

for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in matching_job_params:
matching_job_params[value] = None
matching_job_params["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, matching_method, method_kwargs=matching_params, **matching_job_params
Expand Down Expand Up @@ -308,7 +312,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
merging_params = params["merging"].copy()

if len(merging_params) > 0:
if params["drift_correction"] and motion_folder is not None:
if params["motion_correction"] and motion_folder is not None:
from spikeinterface.preprocessing.motion import load_motion_info

motion_info = load_motion_info(motion_folder)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class CircusClustering:
},
"radius_um": 100,
"n_svd": [5, 2],
"ms_before": 2,
"ms_after": 2,
"ms_before": 0.5,
"ms_after": 0.5,
"noise_levels": None,
"tmp_folder": None,
"job_kwargs": {},
Expand Down
10 changes: 10 additions & 0 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ def test_plot_unit_templates(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test with templates
templates_ext = self.sorting_analyzer_dense.get_extension("templates")
templates = templates_ext.get_data(outputs="Templates")
sw.plot_unit_templates(
templates,
sparsity=self.sparsity_strict,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
else:
# sortingview doesn't support more than 2 shadings
with self.assertRaises(AssertionError):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_depths.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
unit_ids = sorting_analyzer.sorting.unit_ids

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer.sorting)
unit_colors = get_unit_colors(sorting_analyzer)

colors = [unit_colors[unit_id] for unit_id in unit_ids]

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer.sorting)
unit_colors = get_unit_colors(sorting_analyzer)

plot_data = dict(
sorting_analyzer=sorting_analyzer,
Expand Down
6 changes: 5 additions & 1 deletion src/spikeinterface/widgets/unit_templates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from ..core import SortingAnalyzer
from .unit_waveforms import UnitWaveformsWidget
from .base import to_attr

Expand All @@ -17,6 +18,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

dp = to_attr(data_plot)

sorting_analyzer = dp.sorting_analyzer_or_templates
assert isinstance(sorting_analyzer, SortingAnalyzer), "This widget requires a SortingAnalyzer as input"

assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"

# ensure serializable for sortingview
Expand Down Expand Up @@ -50,7 +54,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations)

if not dp.hide_unit_selector:
v_units_table = generate_unit_table_view(dp.sorting_analyzer.sorting)
v_units_table = generate_unit_table_view(sorting_analyzer.sorting)

self.view = vv.Box(
direction="horizontal",
Expand Down
Loading

0 comments on commit cd3dc7a

Please sign in to comment.