Commit

Merge branch 'main' into fast-correlogram-merge
chrishalcrow authored Jan 31, 2025
2 parents 9d95496 + 12a1276 commit d2d5942
Showing 5 changed files with 146 additions and 52 deletions.
81 changes: 59 additions & 22 deletions src/spikeinterface/core/node_pipeline.py
@@ -191,7 +191,7 @@ def __init__(
self._dtype = spike_peak_dtype

self.include_spikes_in_margin = include_spikes_in_margin
if include_spikes_in_margin is not None:
if include_spikes_in_margin:
self._dtype = spike_peak_dtype + [("in_margin", "bool")]

self.peaks = sorting_to_peaks(sorting, extremum_channel_inds, self._dtype)
@@ -228,12 +228,6 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
# get local peaks
sl = self.segment_slices[segment_index]
peaks_in_segment = self.peaks[sl]
# if self.include_spikes_in_margin:
# i0, i1 = np.searchsorted(
# peaks_in_segment["sample_index"], [start_frame - max_margin, end_frame + max_margin]
# )
# else:
# i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame])
i0, i1 = peak_slice

local_peaks = peaks_in_segment[i0:i1]
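For context, the idea behind get_peak_slice is exactly what the deleted comment above spelled out: a np.searchsorted over the peaks' sorted "sample_index" field, widened by the margin when spikes in the margin are kept. A minimal standalone sketch, with invented values:

import numpy as np

# Sketch of the peak-slice lookup from the deleted comment above (values invented).
peaks = np.zeros(5, dtype=[("sample_index", "int64")])
peaks["sample_index"] = [10, 250, 900, 1500, 4000]
start_frame, end_frame, max_margin = 200, 2000, 100

# With spikes-in-margin enabled, widen the chunk window by max_margin on each side.
i0, i1 = np.searchsorted(peaks["sample_index"], [start_frame - max_margin, end_frame + max_margin])
local_peaks = peaks[i0:i1]  # sample indices 250, 900, 1500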
@@ -435,21 +429,59 @@ def compute(self, traces, peaks):
return sparse_wfs


def find_parent_of_type(list_of_parents, parent_type, unique=True):
def find_parent_of_type(list_of_parents, parent_type):
"""
Find a single parent of a given type(s) in a list of parents.
If multiple parents of the given type are found, the first parent is returned.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type | tuple of types
The type of parent to search for.
Returns
-------
parent : PipelineNode or None
The parent of the given type. Returns None if no parent of the given type is found.
"""
if list_of_parents is None:
return None

parents = find_parents_of_type(list_of_parents, parent_type)

if len(parents) > 0:
return parents[0]
else:
return None


def find_parents_of_type(list_of_parents, parent_type):
"""
Find all parents of a given type(s) in a list of parents.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type | tuple of types
The type(s) of parents to search for.
Returns
-------
parents : list of PipelineNode
List of parents of the given type(s). Returns an empty list if no parents of the given type(s) are found.
"""
if list_of_parents is None:
return []

parents = []
for parent in list_of_parents:
if isinstance(parent, parent_type):
parents.append(parent)

if unique and len(parents) == 1:
return parents[0]
elif not unique and len(parents) > 1:
return parents[0]
else:
return None
return parents
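
A quick sketch of how the refactored pair behaves; the classes below are hypothetical stand-ins, not real pipeline nodes:

class FakeRetriever:  # hypothetical stand-in node
    pass

class FakeDetector:  # hypothetical stand-in node
    pass

nodes = [FakeRetriever(), FakeRetriever()]

# find_parents_of_type returns every match, or an empty list.
assert len(find_parents_of_type(nodes, FakeRetriever)) == 2
assert find_parents_of_type(nodes, FakeDetector) == []

# find_parent_of_type returns the first match, or None.
assert find_parent_of_type(nodes, FakeRetriever) is nodes[0]
assert find_parent_of_type(nodes, FakeDetector) is None

This drops the old unique flag, whose deleted branches above returned the first match either way.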


def check_graph(nodes):
@@ -471,7 +503,7 @@ def check_graph(nodes):
assert parent in nodes, f"Node {node} has parent {parent} that was not passed in nodes"
assert (
nodes.index(parent) < i
), f"Node are ordered incorrectly: {node} before {parent} in the pipeline definition."
), f"Nodes are ordered incorrectly: {node} before {parent} in the pipeline definition."

return nodes

@@ -607,12 +639,16 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

recording_segment = recording._recording_segments[segment_index]
node0 = nodes[0]

if isinstance(node0, (SpikeRetriever, PeakRetriever)):
# in this case PeakSource could have no peaks and so no need to load traces just skip
peak_slice = i0, i1 = node0.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
load_trace_and_compute = i0 < i1
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
# get peak slices once for all retrievers
peak_slice_by_retriever = {}
for retriever in retrievers:
peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
peak_slice_by_retriever[retriever] = peak_slice

if len(peak_slice_by_retriever) > 0:
# in this case the retrievers could have no peaks, so we test if any spikes are in the chunk
load_trace_and_compute = any(i0 < i1 for i0, i1 in peak_slice_by_retriever.values())
else:
# PeakDetector always needs traces
load_trace_and_compute = True
@@ -646,7 +682,8 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin)
# set sample index to local
node_output[0]["sample_index"] += extra_margin
elif isinstance(node, PeakSource):
elif isinstance(node, (PeakRetriever, SpikeRetriever)):
peak_slice = peak_slice_by_retriever[node]
node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin, peak_slice)
else:
# TODO later when in master: change the signature of all nodes (or maybe not!)
5 changes: 2 additions & 3 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2033,19 +2033,18 @@ def load(cls, sorting_analyzer):
return None

def load_run_info(self):
run_info = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
run_info_file = extension_folder / "run_info.json"
if run_info_file.is_file():
with open(str(run_info_file), "r") as f:
run_info = json.load(f)
else:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
run_info = None

elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r")
run_info = extension_group.attrs.get("run_info", None)

if run_info is None:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
self.run_info = run_info
19 changes: 12 additions & 7 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -327,13 +327,18 @@ def _run(self, verbose=False):
)

existing_metrics = []
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# Check if we need to propagate any old metrics. If so, we'll do that.
# Otherwise, we'll avoid attempting to load an empty template_metrics.
if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]):

tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

existing_metrics = []
# here we get the loaded extension via the dict only (to avoid full loading from disk after params reset)
37 changes: 29 additions & 8 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
@@ -587,10 +587,14 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
Going from centers to edges is done by taking midpoints and padding with the
left and rightmost centers.
To handle multi-segment recordings, this function works with both:
* array/array input
* list[array]/list[array] input
Parameters
----------
time_bin_centers_s : None or np.array
time_bin_edges_s : None or np.array
time_bin_centers_s : None or np.array or list[np.array]
time_bin_edges_s : None or np.array or list[np.array]
Returns
-------
@@ -600,17 +604,34 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

if time_bin_centers_s is None:
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
if isinstance(time_bin_edges_s, list):
# multi segment case
time_bin_centers_s = []
for be in time_bin_edges_s:
bc, _ = ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=be)
time_bin_centers_s.append(bc)
else:
# simple segment
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
if isinstance(time_bin_centers_s, list):
# multi segment case
time_bin_edges_s = []
for bc in time_bin_centers_s:
_, be = ensure_time_bins(time_bin_centers_s=bc, time_bin_edges_s=None)
time_bin_edges_s.append(be)
else:
# simple segment
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

return time_bin_centers_s, time_bin_edges_s



def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
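
A short sketch of the conversions this now supports, with invented bin values; note that ensure_time_bins returns a (centers, edges) tuple:

import numpy as np

# Single segment: centers are the midpoints of the edges.
edges = np.array([0.0, 1.0, 2.0, 3.0])
centers, _ = ensure_time_bins(time_bin_edges_s=edges)
# centers -> [0.5, 1.5, 2.5]

# Multi segment: one array per segment in, one array per segment out.
edges_per_segment = [np.array([0.0, 1.0, 2.0]), np.array([0.0, 0.5, 1.0])]
centers_per_segment, _ = ensure_time_bins(time_bin_edges_s=edges_per_segment)
# centers_per_segment -> [array([0.5, 1.5]), array([0.25, 0.75])]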
56 changes: 44 additions & 12 deletions src/spikeinterface/sortingcomponents/motion/tests/test_motion_interpolation.py
@@ -10,18 +10,26 @@
interpolate_motion_on_traces,
)
from spikeinterface.sortingcomponents.tests.common import make_dataset

from spikeinterface.core import generate_ground_truth_recording

def make_fake_motion(rec):
# make a fake motion object
duration = rec.get_total_duration()

locs = rec.get_channel_locations()
temporal_bins = np.arange(0.5, duration - 0.49, 0.5)
spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100)
displacement = np.zeros((temporal_bins.size, spatial_bins.size))
displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None]

motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y")
displacement = []
temporal_bins = []
for segment_index in range(rec.get_num_segments()):
duration = rec.get_duration(segment_index=segment_index)
seg_time_bins = np.arange(0.5, duration - 0.49, 0.5)
seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size))
seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None]

temporal_bins.append(seg_time_bins)
displacement.append(seg_disp)

motion = Motion(displacement, temporal_bins, spatial_bins, direction="y")

return motion

@@ -176,7 +184,27 @@ def test_cross_band_interpolation():


def test_InterpolateMotionRecording():
rec, sorting = make_dataset()
# rec, sorting = make_dataset()

# 2 segments
rec, sorting = generate_ground_truth_recording(
durations=[30.0],
sampling_frequency=30000.0,
num_channels=32,
num_units=10,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
seed=2205,
)


motion = make_fake_motion(rec)

rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate")
@@ -187,15 +215,19 @@ def test_InterpolateMotionRecording():

rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels")
assert rec2.channel_ids.size == 24
for ch_id in (0, 1, 14, 15, 16, 17, 30, 31):
for ch_id in ("0", "1", "14", "15", "16", "17", "30", "31"):
assert ch_id not in rec2.channel_ids

traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000)
assert traces.shape == (30000, 24)

traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=[3, 4])
traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=["3", "4"])
assert traces.shape == (30000, 2)

# test dump/load with a multi-segment recording
rec2.dump("rec_motion_interp.pickle")
rec3 = sc.load("rec_motion_interp.pickle")

# import matplotlib.pyplot as plt
# import spikeinterface.widgets as sw
# fig, ax = plt.subplots()
@@ -207,7 +239,7 @@

if __name__ == "__main__":
# test_correct_motion_on_peaks()
test_interpolate_motion_on_traces()
# test_interpolate_motion_on_traces()
# test_interpolation_simple()
# test_InterpolateMotionRecording()
test_cross_band_interpolation()
test_InterpolateMotionRecording()
# test_cross_band_interpolation()
