From cb7ffd611d372d765bc00c6cb08449f25c1413b4 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Tue, 15 Oct 2024 15:32:52 +0100
Subject: [PATCH 1/7] Only load template_metrics on compute if propagating some metrics

---
 src/spikeinterface/core/sortinganalyzer.py | 5 ++---
 .../postprocessing/template_metrics.py | 19 ++++++++++++-------
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 55cbe6070a..e6eb2c23b9 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2033,19 +2033,18 @@ def load(cls, sorting_analyzer):
         return None
 
     def load_run_info(self):
+        run_info = None
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             run_info_file = extension_folder / "run_info.json"
             if run_info_file.is_file():
                 with open(str(run_info_file), "r") as f:
                     run_info = json.load(f)
-            else:
-                warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
-                run_info = None
 
         elif self.format == "zarr":
             extension_group = self._get_zarr_extension_group(mode="r")
             run_info = extension_group.attrs.get("run_info", None)
+
         if run_info is None:
             warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
         self.run_info = run_info
diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 306e9594b8..0a04fa65da 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -335,13 +335,18 @@ def _run(self, verbose=False):
             )
 
         existing_metrics = []
-        tm_extension = self.sorting_analyzer.get_extension("template_metrics")
-        if (
-            delete_existing_metrics is False
-            and tm_extension is not None
-            and tm_extension.data.get("metrics") is not None
-        ):
-            existing_metrics = tm_extension.params["metric_names"]
+
+        # Check if we need to propagate any old metrics. If so, we'll do that.
+        # Otherwise, we'll avoid attempting to load an empty template_metrics.
+        if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]):
+
+            tm_extension = self.sorting_analyzer.get_extension("template_metrics")
+            if (
+                delete_existing_metrics is False
+                and tm_extension is not None
+                and tm_extension.data.get("metrics") is not None
+            ):
+                existing_metrics = tm_extension.params["metric_names"]
 
         # append the metrics which were previously computed
         for metric_name in set(existing_metrics).difference(metrics_to_compute):
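Before the next patch, a quick illustration of the guard added above: the old template_metrics extension is only loaded when some requested metrics are not being recomputed and must be carried over. A minimal sketch with made-up metric names (the real check lives in the _run method shown above):

    metric_names = ["peak_to_valley", "half_width", "exp_decay"]  # all requested metrics
    metrics_to_compute = ["exp_decay"]  # only this one is (re)computed
    existing_metrics = ["peak_to_valley", "half_width"]  # stand-in for tm_extension.params["metric_names"]

    # the guard: only touch the old extension if something must be propagated
    if set(metrics_to_compute) != set(metric_names):
        carried_over = set(existing_metrics).difference(metrics_to_compute)
        print(sorted(carried_over))  # ['half_width', 'peak_to_valley']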
From b64669a0d95b3e6ec8de7982ddb38efa8500412c Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 27 Jan 2025 17:20:21 +0100
Subject: [PATCH 2/7] Retrieve correct peak slice for multiple retrievers

---
 src/spikeinterface/core/node_pipeline.py | 47 ++++++++++++++++--------
 1 file changed, 31 insertions(+), 16 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 53c2445c77..6167a3ab75 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -191,7 +191,7 @@ def __init__(
         self._dtype = spike_peak_dtype
 
         self.include_spikes_in_margin = include_spikes_in_margin
-        if include_spikes_in_margin is not None:
+        if include_spikes_in_margin:
             self._dtype = spike_peak_dtype + [("in_margin", "bool")]
 
         self.peaks = sorting_to_peaks(sorting, extremum_channel_inds, self._dtype)
@@ -228,12 +228,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
         # get local peaks
         sl = self.segment_slices[segment_index]
         peaks_in_segment = self.peaks[sl]
-        # if self.include_spikes_in_margin:
-        #     i0, i1 = np.searchsorted(
-        #         peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
-        #     )
-        # else:
-        #     i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
        i0, i1 = peak_slice
 
         local_peaks = peaks_in_segment[i0:i1]
@@ -452,6 +446,21 @@ def find_parent_of_type(list_of_parents, parent_type, unique=True):
         return None
 
 
+def find_parents_of_type(list_of_parents, parent_type):
+    if list_of_parents is None:
+        return None
+
+    parents = []
+    for parent in list_of_parents:
+        if isinstance(parent, parent_type):
+            parents.append(parent)
+
+    if len(parents) > 0:
+        return parents
+    else:
+        return None
+
+
 def check_graph(nodes):
     """
     Check that node list is orderd in a good (parents are before children)
@@ -471,7 +480,7 @@ def check_graph(nodes):
         assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
         assert (
             nodes.index(parent) < i
-        ), f"Node are ordered incorrectly: {node} beforeĀ {parent} in the pipeline definition."
+        ), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition."
 
     return nodes
 
@@ -607,12 +616,17 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
     skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]
 
     recording_segment = recording._recording_segments[segment_index]
-    node0 = nodes[0]
-
-    if isinstance(node0, (SpikeRetriever, PeakRetriever)):
-        # in this case PeakSource could have no peaks and so no need to load traces just skip
-        peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
-        load_trace_and_compute = i0 < i1
+    retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
+    # get peak slices once for all retrievers
+    retriever_node = None
+    peak_slice_by_retriever = {}
+    for retriever in retrievers:
+        peak_slice = i0, i1 = retriever_node.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
+        peak_slice_by_retriever[retriever] = peak_slice
+
+    if len(peak_slice_by_retriever) > 0:
+        # in this case the retrievers could have no peaks, so we test if any spikes are in the chunk
+        load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
     else:
         # PeakDetector always need traces
         load_trace_and_compute = True
@@ -627,7 +641,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
         )
         # compute the graph
         pipeline_outputs = {}
-        for node in nodes:
+        for i, node in enumerate(nodes):
             node_parents = node.parents if node.parents else list()
             node_input_args = tuple()
             for parent in node_parents:
@@ -646,7 +660,8 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
                 node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
                 # set sample index to local
                 node_output[0]["sample_index"] += extra_margin
-            elif isinstance(node, PeakSource):
+            elif isinstance(node, (PeakRetriever, SpikeRetriever)):
+                peak_slice = peak_slice_by_retriever[node]
                 node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
             else:
                 # TODO later when in master: change the signature of all nodes (or maybe not!)
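Between these patches, a self-contained sketch of how the new find_parents_of_type helper behaves (the node classes below are empty stand-ins, not the real SpikeInterface nodes; note that a later patch in this series changes the no-match return value from None to an empty list):

    class PeakRetriever:
        pass

    class SpikeRetriever:
        pass

    class DownstreamNode:
        pass

    def find_parents_of_type(list_of_parents, parent_type):
        # same logic as the helper added in PATCH 2/7
        if list_of_parents is None:
            return None
        parents = [p for p in list_of_parents if isinstance(p, parent_type)]
        return parents if len(parents) > 0 else None

    nodes = [PeakRetriever(), SpikeRetriever(), DownstreamNode()]
    # a tuple of types matches either retriever kind, preserving pipeline order
    assert len(find_parents_of_type(nodes, (PeakRetriever, SpikeRetriever))) == 2
    # no match and no parents at all both yield None in this version
    assert find_parents_of_type(nodes, dict) is None
    assert find_parents_of_type(None, PeakRetriever) is None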
From cf7c3935907cd01a937a60b81acc52c152a7aa5a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 27 Jan 2025 17:26:01 +0100
Subject: [PATCH 3/7] oops

---
 src/spikeinterface/core/node_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 6167a3ab75..202b25b9a3 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -621,7 +621,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
     retriever_node = None
     peak_slice_by_retriever = {}
     for retriever in retrievers:
-        peak_slice = i0, i1 = retriever_node.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
+        peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
         peak_slice_by_retriever[retriever] = peak_slice

From acd5c171a4603b07eb97ca4c47210831e34ff04f Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 27 Jan 2025 17:38:33 +0100
Subject: [PATCH 4/7] Remove unused unique arg and fix empty parents

---
 src/spikeinterface/core/node_pipeline.py | 50 +++++++++++++++++-------
 1 file changed, 36 insertions(+), 14 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 202b25b9a3..aac4ee8214 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -429,36 +429,59 @@ def compute(self, traces, peaks):
         return sparse_wfs
 
 
-def find_parent_of_type(list_of_parents, parent_type, unique=True):
+def find_parent_of_type(list_of_parents, parent_type):
+    """
+    Find a single parent of a given type (or types) in a list of parents.
+    If multiple parents of the given type are found, the first parent is returned.
+
+    Parameters
+    ----------
+    list_of_parents : list of PipelineNode
+        List of parents to search through.
+    parent_type : type
+        The type of parent to search for.
+
+    Returns
+    -------
+    parent : PipelineNode or None
+        The parent of the given type. Returns None if no parent of the given type is found.
+    """
     if list_of_parents is None:
         return None
 
-    parents = []
-    for parent in list_of_parents:
-        if isinstance(parent, parent_type):
-            parents.append(parent)
+    parents = find_parents_of_type(list_of_parents, parent_type)
 
-    if unique and len(parents) == 1:
-        return parents[0]
-    elif not unique and len(parents) > 1:
+    if len(parents) > 0:
         return parents[0]
     else:
         return None
 
 
 def find_parents_of_type(list_of_parents, parent_type):
+    """
+    Find all parents of a given type (or types) in a list of parents.
+
+    Parameters
+    ----------
+    list_of_parents : list of PipelineNode
+        List of parents to search through.
+    parent_type : type | tuple of types
+        The type(s) of parents to search for.
+
+    Returns
+    -------
+    parents : list of PipelineNode
+        List of parents of the given type(s). Returns an empty list if no parents of the given type(s) are found.
+    """
     if list_of_parents is None:
-        return None
+        return []
 
     parents = []
     for parent in list_of_parents:
         if isinstance(parent, parent_type):
             parents.append(parent)
 
-    if len(parents) > 0:
-        return parents
-    else:
-        return None
+    return parents
 
 
 def check_graph(nodes):
@@ -618,7 +641,6 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
     recording_segment = recording._recording_segments[segment_index]
     retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
     # get peak slices once for all retrievers
-    retriever_node = None
     peak_slice_by_retriever = {}
     for retriever in retrievers:
         peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)

From 47e183953c3174f29aae3dedd30c5402ec736892 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 27 Jan 2025 17:52:23 +0100
Subject: [PATCH 5/7] Update src/spikeinterface/core/node_pipeline.py

Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
---
 src/spikeinterface/core/node_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index aac4ee8214..9a81534333 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -663,7 +663,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
         )
         # compute the graph
         pipeline_outputs = {}
-        for i, node in enumerate(nodes):
+        for node in nodes:
             node_parents = node.parents if node.parents else list()
             node_input_args = tuple()
             for parent in node_parents:

From 988b2a0ca0edc58c049bcd4b72392b808bf917eb Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 28 Jan 2025 09:52:19 +0100
Subject: [PATCH 6/7] Apply suggestions from code review

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/core/node_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 9a81534333..9496e03ba7 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -438,7 +438,7 @@ def find_parent_of_type(list_of_parents, parent_type):
     ----------
     list_of_parents : list of PipelineNode
         List of parents to search through.
-    parent_type : type
+    parent_type : type | tuple of types
         The type of parent to search for.
 
     Returns
@@ -503,7 +503,7 @@ def check_graph(nodes):
         assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
         assert (
             nodes.index(parent) < i
-        ), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition."
+        ), f"Nodes are ordered incorrectly: {node} before {parent} in the pipeline definition."
 
     return nodes
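With patches 2-6 combined, the chunk worker loads traces only if at least one retriever has events in the chunk. A toy sketch of that decision, with invented (i0, i1) slice pairs standing in for get_peak_slice results:

    # invented slice pairs: empty for one retriever, 17 spikes for the other
    peak_slice_by_retriever = {"peak_retriever": (120, 120), "spike_retriever": (40, 57)}

    if len(peak_slice_by_retriever) > 0:
        # load traces if any retriever has at least one event in this chunk
        load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
    else:
        # a PeakDetector pipeline always needs traces
        load_trace_and_compute = True

    assert load_trace_and_compute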
From a0d658daf34a2d2ff74b24c72bc98dc371ae8fb3 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Thu, 30 Jan 2025 14:43:18 +0100
Subject: [PATCH 7/7] Multi segment handling in ensure_time_bins

---
 .../sortingcomponents/motion/motion_utils.py | 37 +++++++++---
 .../motion/tests/test_motion_interpolation.py | 56 +++++++++++++++----
 2 files changed, 73 insertions(+), 20 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py
index 680d75f221..5c02646497 100644
--- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py
+++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py
@@ -587,10 +587,14 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
     Going from centers to edges is done by taking midpoints and padding
     with the left and rightmost centers.
 
+    To handle multi-segment data, this function works with both:
+    * array/array input
+    * list[array]/list[array] input
+
     Parameters
     ----------
-    time_bin_centers_s : None or np.array
-    time_bin_edges_s : None or np.array
+    time_bin_centers_s : None or np.array or list[np.array]
+    time_bin_edges_s : None or np.array or list[np.array]
 
     Returns
     -------
@@ -600,17 +604,34 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
         raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")
 
     if time_bin_centers_s is None:
-        assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
-        time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
+        if isinstance(time_bin_edges_s, list):
+            # multi segment case
+            time_bin_centers_s = []
+            for be in time_bin_edges_s:
+                bc, _ = ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=be)
+                time_bin_centers_s.append(bc)
+        else:
+            # single segment
+            assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
+            time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
 
     if time_bin_edges_s is None:
-        time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
-        time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
-        if time_bin_centers_s.size > 2:
-            time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
+        if isinstance(time_bin_centers_s, list):
+            # multi segment case
+            time_bin_edges_s = []
+            for bc in time_bin_centers_s:
+                _, be = ensure_time_bins(time_bin_centers_s=bc, time_bin_edges_s=None)
+                time_bin_edges_s.append(be)
+        else:
+            # single segment
+            time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
+            time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
+            if time_bin_centers_s.size > 2:
+                time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
 
     return time_bin_centers_s, time_bin_edges_s
 
+
 def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
     return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
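Before the test changes, a standalone sketch of the centers-to-edges conversion that ensure_time_bins now applies per segment (a reimplementation for illustration only; the real helper is ensure_time_bins above):

    import numpy as np

    def centers_to_edges(centers):
        # midpoints between centers, padded with the outermost centers,
        # mirroring the single-segment branch of ensure_time_bins
        edges = np.empty(centers.shape[0] + 1, dtype=centers.dtype)
        edges[[0, -1]] = centers[[0, -1]]
        if centers.size > 2:
            edges[1:-1] = 0.5 * (centers[1:] + centers[:-1])
        return edges

    # multi-segment input: one array of bin centers per segment
    centers_per_segment = [np.array([0.5, 1.0, 1.5]), np.array([0.25, 0.75, 1.25])]
    edges_per_segment = [centers_to_edges(c) for c in centers_per_segment]
    print(edges_per_segment[0])  # [0.5  0.75 1.25 1.5 ]
    print(edges_per_segment[1])  # [0.25 0.5  1.   1.25]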
diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py
index e4ba870325..807b8e6c9e 100644
--- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py
+++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py
@@ -10,18 +10,26 @@
     interpolate_motion_on_traces,
 )
 from spikeinterface.sortingcomponents.tests.common import make_dataset
-
+from spikeinterface.core import generate_ground_truth_recording
 
 
 def make_fake_motion(rec):
     # make a fake motion object
-    duration = rec.get_total_duration()
+
     locs = rec.get_channel_locations()
-    temporal_bins = np.arange(0.5, duration - 0.49, 0.5)
     spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100)
-    displacement = np.zeros((temporal_bins.size, spatial_bins.size))
-    displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None]
-    motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y")
+    displacement = []
+    temporal_bins = []
+    for segment_index in range(rec.get_num_segments()):
+        duration = rec.get_duration(segment_index=segment_index)
+        seg_time_bins = np.arange(0.5, duration - 0.49, 0.5)
+        seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size))
+        seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None]
+
+        temporal_bins.append(seg_time_bins)
+        displacement.append(seg_disp)
+
+    motion = Motion(displacement, temporal_bins, spatial_bins, direction="y")
 
     return motion
@@ -176,7 +184,27 @@ def test_cross_band_interpolation():
 
 
 def test_InterpolateMotionRecording():
-    rec, sorting = make_dataset()
+    # rec, sorting = make_dataset()
+
+    # 2 segments
+    rec, sorting = generate_ground_truth_recording(
+        durations=[30.0],
+        sampling_frequency=30000.0,
+        num_channels=32,
+        num_units=10,
+        generate_probe_kwargs=dict(
+            num_columns=2,
+            xpitch=20,
+            ypitch=20,
+            contact_shapes="circle",
+            contact_shape_params={"radius": 6},
+        ),
+        generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
+        noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
+        seed=2205,
+    )
+
+
     motion = make_fake_motion(rec)
 
     rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate")
@@ -187,15 +215,19 @@ def test_InterpolateMotionRecording():
 
     rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels")
     assert rec2.channel_ids.size == 24
-    for ch_id in (0, 1, 14, 15, 16, 17, 30, 31):
+    for ch_id in ("0", "1", "14", "15", "16", "17", "30", "31"):
         assert ch_id not in rec2.channel_ids
 
     traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000)
     assert traces.shape == (30000, 24)
 
-    traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=[3, 4])
+    traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=["3", "4"])
     assert traces.shape == (30000, 2)
 
+    # test dump/load when there are multiple segments
+    rec2.dump("rec_motion_interp.pickle")
+    rec3 = sc.load("rec_motion_interp.pickle")
+
     # import matplotlib.pyplot as plt
     # import spikeinterface.widgets as sw
     # fig, ax = plt.subplots()
@@ -207,7 +239,7 @@ def test_InterpolateMotionRecording():
 
 if __name__ == "__main__":
     # test_correct_motion_on_peaks()
-    test_interpolate_motion_on_traces()
+    # test_interpolate_motion_on_traces()
     # test_interpolation_simple()
-    # test_InterpolateMotionRecording()
-    test_cross_band_interpolation()
+    test_InterpolateMotionRecording()
+    # test_cross_band_interpolation()
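To close, a hedged sketch of the Motion object that the updated make_fake_motion builds: one displacement array and one temporal-bin array per segment (the segment durations and spatial bins are invented here, and the import path follows the convention of the test module above):

    import numpy as np
    from spikeinterface.sortingcomponents.motion import Motion

    spatial_bins = np.array([0.0, 100.0, 200.0])
    displacement = []
    temporal_bins = []
    for duration in (30.0, 10.0):  # two invented segment durations, in seconds
        seg_time_bins = np.arange(0.5, duration - 0.49, 0.5)
        seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size))
        seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None]
        temporal_bins.append(seg_time_bins)
        displacement.append(seg_disp)

    # one entry per segment, matching the multi-segment handling above
    motion = Motion(displacement, temporal_bins, spatial_bins, direction="y")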