From 582a09c4400497f5799345353bf6f3386c119fb1 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 12 Jan 2024 10:13:26 -0800 Subject: [PATCH] Fix SpikeSortingOutput get_recording and get_sorting (#761) * Avoid merge restrict --- src/spyglass/spikesorting/merge.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/spyglass/spikesorting/merge.py b/src/spyglass/spikesorting/merge.py index f6be53f28..d01b720c0 100644 --- a/src/spyglass/spikesorting/merge.py +++ b/src/spyglass/spikesorting/merge.py @@ -50,22 +50,16 @@ class CuratedSpikeSorting(SpyglassMixin, dj.Part): # noqa: F811 def get_recording(cls, key): """get the recording associated with a spike sorting output""" - recording_key = cls.merge_restrict(key).proj() - query = ( - source_class_dict[ - to_camel_case(cls.merge_get_parent(key).table_name) - ] - & recording_key - ) - return query.get_recording(recording_key) + source_table = source_class_dict[ + to_camel_case(cls.merge_get_parent(key).table_name) + ] + query = source_table & cls.merge_get_part(key) + return query.get_recording(query.fetch("KEY")) def get_sorting(cls, key): """get the sorting associated with a spike sorting output""" - sorting_key = cls.merge_restrict(key).proj() - query = ( - source_class_dict[ - to_camel_case(cls.merge_get_parent(key).table_name) - ] - & sorting_key - ) - return query.get_sorting(sorting_key) + source_table = source_class_dict[ + to_camel_case(cls.merge_get_parent(key).table_name) + ] + query = source_table & cls.merge_get_part(key) + return query.get_sorting(query.fetch("KEY"))