Skip to content

Commit

Permalink
Remove unused unique arg and fix empty parents
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jan 27, 2025
1 parent cf7c393 commit acd5c17
Showing 1 changed file with 36 additions and 14 deletions.
50 changes: 36 additions & 14 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,36 +429,59 @@ def compute(self, traces, peaks):
return sparse_wfs


def find_parent_of_type(list_of_parents, parent_type, unique=True):
def find_parent_of_type(list_of_parents, parent_type):
"""
Find a single parent of a given type(s) in a list of parents.
If multiple parents of the given type are found, the first parent is returned.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type
The type of parent to search for.
Returns
-------
parent : PipelineNode or None
The parent of the given type. Returns None if no parent of the given type is found.
"""
if list_of_parents is None:
return None

parents = []
for parent in list_of_parents:
if isinstance(parent, parent_type):
parents.append(parent)
parents = find_parents_of_type(list_of_parents, parent_type)

if unique and len(parents) == 1:
return parents[0]
elif not unique and len(parents) > 1:
if len(parents) > 0:
return parents[0]
else:
return None


def find_parents_of_type(list_of_parents, parent_type):
"""
Find all parents of a given type(s) in a list of parents.
Parameters
----------
list_of_parents : list of PipelineNode
List of parents to search through.
parent_type : type | tuple of types
The type(s) of parents to search for.
Returns
-------
parents : list of PipelineNode
List of parents of the given type(s). Returns an empty list if no parents of the given type(s) are found.
"""
if list_of_parents is None:
return None
return []

parents = []
for parent in list_of_parents:
if isinstance(parent, parent_type):
parents.append(parent)

if len(parents) > 0:
return parents
else:
return None
return parents


def check_graph(nodes):
Expand Down Expand Up @@ -618,7 +641,6 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c
recording_segment = recording._recording_segments[segment_index]
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
# get peak slices once for all retrievers
retriever_node = None
peak_slice_by_retriever = {}
for retriever in retrievers:
peak_slice = i0, i1 = retriever.get_peak_slice(segment_index, start_frame, end_frame, max_margin)
Expand Down

0 comments on commit acd5c17

Please sign in to comment.