Merge branch 'SpikeInterface:main' into shm_templates
yger authored Jan 31, 2025
2 parents 5dcbc9c + 12a1276 commit 3e4b039
Showing 11 changed files with 159 additions and 65 deletions.
12 changes: 6 additions & 6 deletions examples/tutorials/curation/plot_1_automated_curation.py
@@ -209,8 +209,8 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
# trained on real data.
#
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
# V1,SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and
# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into
# V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
# https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
# `noise` or `not-noise` and the other will classify the `not-noise` units into single
# unit activity (sua) units and multi-unit activity (mua) units.
#
@@ -221,8 +221,8 @@ def calculate_moving_avg(label_df, confidence_label, window_size):

# Apply the noise/not-noise model
noise_neuron_labels = sc.auto_label_units(
sorting_analyzer = sorting_analyzer,
repo_id = "AnoushkaJain3/noise_neural_classifier",
sorting_analyzer=sorting_analyzer,
repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
trust_model=True,
)
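
The tutorial next builds `analyzer_neural` from the units that were not labelled as noise, a step that is not part of this diff. A minimal sketch of the bridge between the two models, assuming `auto_label_units` returns a DataFrame with a "prediction" column indexed by unit id and that `SortingAnalyzer.select_units` is available (both assumptions, see the full tutorial):

# Hypothetical bridge between the two models (not part of this diff):
# keep only the non-noise units and build the analyzer used by the sua/mua model below.
neural_unit_ids = noise_neuron_labels[noise_neuron_labels["prediction"] != "noise"].index
analyzer_neural = sorting_analyzer.select_units(neural_unit_ids)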

@@ -231,8 +231,8 @@ def calculate_moving_avg(label_df, confidence_label, window_size):

# Apply the sua/mua model
sua_mua_labels = sc.auto_label_units(
sorting_analyzer = analyzer_neural,
repo_id = "AnoushkaJain3/sua_mua_classifier",
sorting_analyzer=analyzer_neural,
repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
trust_model=True,
)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -62,7 +62,7 @@ extractors = [
"MEArec>=1.8",
"pynwb>=2.6.0",
"hdmf-zarr>=0.11.0",
"pyedflib>=0.1.30",
"pyedflib>=0.1.30,<0.1.39",
"sonpy;python_version<'3.10'",
"lxml", # lxml for neuroscope
"scipy",
2 changes: 1 addition & 1 deletion src/spikeinterface/core/globals.py
@@ -98,7 +98,7 @@ def is_set_global_dataset_folder() -> bool:

########################################
_default_job_kwargs = dict(
pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1
pool_engine="process", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1
)

global global_job_kwargs
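This hunk reverts the default pool_engine from "thread" back to "process"; a thread pool can still be requested explicitly through the global job kwargs. A small sketch, assuming set_global_job_kwargs and get_global_job_kwargs are importable from spikeinterface.core as exercised in the tests further down:

from spikeinterface.core import get_global_job_kwargs, set_global_job_kwargs

# with this change the default engine is "process" again
print(get_global_job_kwargs()["pool_engine"])  # expected: "process"

# opting back into the thread engine for a session; partial updates are allowed
set_global_job_kwargs(pool_engine="thread", n_jobs=4, chunk_duration="1s")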
81 changes: 59 additions & 22 deletions src/spikeinterface/core/node_pipeline.py
@@ -191,7 +191,7 @@ def __init__(
self._dtype = spike_peak_dtype

self.include_spikes_in_margin = include_spikes_in_margin
if include_spikes_in_margin is not None:
if include_spikes_in_margin:
self._dtype = spike_peak_dtype + [("in_margin", "bool")]

self.peaks = sorting_to_peaks(sorting, extremum_channel_inds, self._dtype)
@@ -228,12 +228,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# if self.include_spikes_in_margin:
# i0, i1 = np.searchsorted(
# peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
# )
# else:
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice

local_peaks = peaks_in_segment[i0:i1]
@@ -435,21 +429,59 @@ def compute(self, traces, peaks):
return sparse_wfs


def find_parent_of_type(list_of_parents, parent_type, unique=True):
def find_parent_of_type(list_of_parents, parent_type):
"""
Find a single parent of a given type(s) in a list of parents.
If multiple parents of the given type are found, the first parent is returned.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type | tuple of types
The type of parent to search for.
Returns
-------
parent : PipelineNode or None
The parent of the given type. Returns None if no parent of the given type is found.
"""
if list_of_parents is None:
return None

parents = find_parents_of_type(list_of_parents, parent_type)

if len(parents) > 0:
return parents[0]
else:
return None


def find_parents_of_type(list_of_parents, parent_type):
"""
Find all parents of a given type(s) in a list of parents.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type | tuple of types
The type(s) of parents to search for.
Returns
-------
parents : list of PipelineNode
List of parents of the given type(s). Returns an empty list if no parents of the given type(s) are found.
"""
if list_of_parents is None:
return []

parents = []
for parent in list_of_parents:
if isinstance(parent, parent_type):
parents.append(parent)

if unique and len(parents) == 1:
return parents[0]
elif not unique and len(parents) > 1:
return parents[0]
else:
return None
return parents


def check_graph(nodes):
@@ -471,7 +503,7 @@ def check_graph(nodes):
assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
assert (
nodes.index(parent) < i
), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition."
), f"Nodes are ordered incorrectly: {node} before {parent} in the pipeline definition."

return nodes

@@ -607,12 +639,16 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

recording_segment = recording._recording_segments[segment_index]
node0 = nodes[0]

if isinstance(node0, (SpikeRetriever, PeakRetriever)):
# in this case PeakSource could have no peaks and so no need to load traces just skip
peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
load_trace_and_compute = i0 < i1
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
# get peak slices once for all retrievers
peak_slice_by_retriever = {}
for retriever in retrievers:
peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
peak_slice_by_retriever[retriever] = peak_slice

if len(peak_slice_by_retriever) > 0:
# in this case the retrievers could have no peaks, so we test if any spikes are in the chunk
load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
else:
# PeakDetector always need traces
load_trace_and_compute = True
@@ -646,7 +682,8 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
elif isinstance(node, (PeakRetriever, SpikeRetriever)):
peak_slice = peak_slice_by_retriever[node]
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
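The refactor above replaces the single node0 check with a dict of peak slices keyed by retriever, so traces are read only when at least one retriever has peaks in the chunk. A standalone sketch of that gating logic with toy data; the searchsorted-based get_peak_slice here mirrors the commented-out block removed earlier in this file and is an assumption, not the pipeline classes:

import numpy as np

def get_peak_slice(sample_indices, start_frame, end_frame, max_margin):
    # sorted sample indices -> half-open slice of the peaks falling inside the padded chunk
    i0, i1 = np.searchsorted(sample_indices, [start_frame - max_margin, end_frame + max_margin])
    return int(i0), int(i1)

retrievers = {
    "peak_retriever": np.array([10, 250, 900]),   # peaks inside the chunk
    "spike_retriever": np.array([5000, 6000]),    # spikes far outside the chunk
}
peak_slice_by_retriever = {
    name: get_peak_slice(samples, start_frame=0, end_frame=1000, max_margin=100)
    for name, samples in retrievers.items()
}
# traces are loaded only if some retriever has something to do in this chunk
load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
print(peak_slice_by_retriever, load_trace_and_compute)
# {'peak_retriever': (0, 3), 'spike_retriever': (0, 0)} True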
5 changes: 2 additions & 3 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2033,19 +2033,18 @@ def load(cls, sorting_analyzer):
return None

def load_run_info(self):
run_info = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
run_info_file = extension_folder / "run_info.json"
if run_info_file.is_file():
with open(str(run_info_file), "r") as f:
run_info = json.load(f)
else:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
run_info = None

elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r")
run_info = extension_group.attrs.get("run_info", None)

if run_info is None:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
self.run_info = run_info
6 changes: 3 additions & 3 deletions src/spikeinterface/core/tests/test_globals.py
@@ -37,7 +37,7 @@ def test_global_tmp_folder(create_cache_folder):

def test_global_job_kwargs():
job_kwargs = dict(
pool_engine="thread",
pool_engine="process",
n_jobs=4,
chunk_duration="1s",
progress_bar=True,
@@ -47,7 +47,7 @@ def test_global_job_kwargs():
global_job_kwargs = get_global_job_kwargs()

assert global_job_kwargs == dict(
pool_engine="thread",
pool_engine="process",
n_jobs=1,
chunk_duration="1s",
progress_bar=True,
@@ -62,7 +62,7 @@ def test_global_job_kwargs():
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(
pool_engine="thread",
pool_engine="process",
n_jobs=2,
chunk_duration="1s",
progress_bar=True,
19 changes: 12 additions & 7 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -327,13 +327,18 @@ def _run(self, verbose=False):
)

existing_metrics = []
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# Check if we need to propagate any old metrics. If so, we'll do that.
# Otherwise, we'll avoid attempting to load an empty template_metrics.
if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]):

tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

existing_metrics = []
# here we get in the loaded via the dict only (to avoid full loading from disk after params reset)
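The new comment and guard above only propagate previously computed metrics when more metric names were requested than are being recomputed. A toy illustration of that condition in plain Python (the metric names are assumed examples):

metric_names = ["peak_to_valley", "half_width", "exp_decay"]  # everything requested
metrics_to_compute = ["exp_decay"]                            # only this one is recomputed
delete_existing_metrics = False
previously_computed = ["peak_to_valley", "half_width"]        # pretend these sit in the extension

existing_metrics = []
if set(metrics_to_compute) != set(metric_names) and not delete_existing_metrics:
    existing_metrics = previously_computed
print(existing_metrics)  # ['peak_to_valley', 'half_width']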
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/external/kilosort.py
@@ -206,7 +206,7 @@ def _get_specific_options(cls, ops, params) -> dict:

# options for posthoc merges (under construction)
ops["fracse"] = 0.1 # binning step along discriminant axis for posthoc merges (in units of sd)
ops["epu"] = np.Inf
ops["epu"] = np.inf

ops["ForceMaxRAMforDat"] = 20e9 # maximum RAM the algorithm will try to use; on Windows it will autodetect.

2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/external/kilosortbase.py
@@ -96,7 +96,7 @@ def _generate_ops_file(cls, recording, params, sorter_output_folder, binary_file
ops["fbinary"] = str(binary_file_path.absolute()) # will be created for 'openEphys'
ops["fproc"] = str((sorter_output_folder / "temp_wh.dat").absolute()) # residual from RAM of preprocessed data
ops["root"] = str(sorter_output_folder.absolute()) # 'openEphys' only: where raw files are
ops["trange"] = [0, np.Inf] # time range to sort
ops["trange"] = [0, np.inf] # time range to sort
ops["chanMap"] = str((sorter_output_folder / "chanMap.mat").absolute())

ops["fs"] = recording.get_sampling_frequency() # sample rate
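Both KiloSort wrappers swap np.Inf for np.inf: the capitalized alias was removed in NumPy 2.0, while the lowercase spelling works on 1.x and 2.x alike. For reference:

import numpy as np

# np.inf exists on every NumPy release; np.Inf was dropped in NumPy 2.0
ops = {"trange": [0, np.inf], "epu": np.inf}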
37 changes: 29 additions & 8 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
@@ -587,10 +587,14 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
Going from centers to edges is done by taking midpoints and padding with the
left and rightmost centers.
To handle multi-segment data, this function works with both:
* array/array input
* list[array]/list[array] input
Parameters
----------
time_bin_centers_s : None or np.array
time_bin_edges_s : None or np.array
time_bin_centers_s : None or np.array or list[np.array]
time_bin_edges_s : None or np.array or list[np.array]
Returns
-------
@@ -600,17 +604,34 @@
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

if time_bin_centers_s is None:
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
if isinstance(time_bin_edges_s, list):
# multi segment case
time_bin_centers_s = []
for be in time_bin_edges_s:
bc, _ = ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=be)
time_bin_centers_s.append(bc)
else:
# simple segment
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
if isinstance(time_bin_centers_s, list):
# multi segment case
time_bin_edges_s = []
for bc in time_bin_centers_s:
_, be = ensure_time_bins(time_bin_centers_s=bc, time_bin_edges_s=None)
time_bin_edges_s.append(be)
else:
# simple segment
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

return time_bin_centers_s, time_bin_edges_s



def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
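
A short usage sketch of the centers/edges conversion described in the docstring above, assuming the module path shown in the file header of this diff:

import numpy as np
from spikeinterface.sortingcomponents.motion.motion_utils import ensure_time_bins

# single segment: edges are the midpoints, padded with the outermost centers
centers = np.array([0.5, 1.5, 2.5])
_, edges = ensure_time_bins(time_bin_centers_s=centers)
print(edges)  # expected: [0.5, 1.0, 2.0, 2.5]

# multi segment: a list of center arrays comes back as a list of edge arrays
centers_list = [np.array([0.5, 1.5, 2.5]), np.array([10.0, 11.0, 12.0])]
_, edges_list = ensure_time_bins(time_bin_centers_s=centers_list)
print([e.tolist() for e in edges_list])
# expected: [[0.5, 1.0, 2.0, 2.5], [10.0, 10.5, 11.5, 12.0]]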