diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index fd95b9bee..8dd7d8bcf 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -11,7 +11,6 @@ import copy import uuid -from functools import wraps from pathlib import Path import datajoint as dj @@ -33,31 +32,12 @@ from spyglass.position.position_merge import PositionOutput # noqa: F401 from spyglass.settings import config from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger +from spyglass.utils.dj_helper_fn import full_key_decorator from spyglass.utils.spikesorting import firing_rate_from_spike_indicator schema = dj.schema("decoding_clusterless_v1") -def classmethod_full_key_decorator(required_keys=[]): - def decorator(method): - @wraps(method) - def wrapper(cls, key=None, *args, **kwargs): - # Ensure key is not None - if key is None: - key = {} - - # Check if required keys are in key, and fetch if not - if not all([k in key for k in required_keys]): - key = (cls() & key).fetch1("KEY") - - # Call the original method with the modified key - return method(cls, key, *args, **kwargs) - - return wrapper - - return decorator - - @schema class UnitWaveformFeaturesGroup(SpyglassMixin, dj.Manual): definition = """ @@ -336,8 +316,9 @@ def fetch_model(self): """Retrieve the decoding model""" return ClusterlessDetector.load_model(self.fetch1("classifier_path")) - @staticmethod - def fetch_environments(key): + @classmethod + @full_key_decorator(required_keys=["decoding_param_name"]) + def fetch_environments(cls, key): """Fetch the environments for the decoding model Parameters @@ -375,8 +356,16 @@ def fetch_environments(key): return classifier.environments - @staticmethod - def fetch_position_info(key): + @classmethod + @full_key_decorator( + required_keys=[ + "nwb_file_name", + "position_group_name", + "encoding_interval", + "decoding_interval", + ] + ) + def fetch_position_info(cls, key): """Fetch the position information for the decoding model Parameters @@ -401,8 +390,16 @@ def fetch_position_info(key): return position_info, position_variable_names - @staticmethod - def fetch_linear_position_info(key): + @classmethod + @full_key_decorator( + required_keys=[ + "nwb_file_name", + "position_group_name", + "encoding_interval", + "decoding_interval", + ] + ) + def fetch_linear_position_info(cls, key): """Fetch the position information and project it onto the track graph Parameters @@ -437,8 +434,11 @@ def fetch_linear_position_info(key): axis=1, ).loc[min_time:max_time] - @staticmethod - def fetch_spike_data(key, filter_by_interval=True): + @classmethod + @full_key_decorator( + required_keys=["nwb_file_name", "waveform_features_group_name"] + ) + def fetch_spike_data(cls, key, filter_by_interval=True): """Fetch the spike times for the decoding model Parameters @@ -488,9 +488,6 @@ def fetch_spike_data(key, filter_by_interval=True): return new_spike_times, new_waveform_features @classmethod - @classmethod_full_key_decorator( - required_keys=["nwb_file_name", "waveform_features_group_name"] - ) def get_spike_indicator(cls, key, time): """get spike indicator matrix for the group diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 9e4c2c3ba..8c1465b08 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -33,6 +33,7 @@ SpikeSortingOutput, ) # noqa: F401 from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.dj_helper_fn import full_key_decorator schema = dj.schema("decoding_sorted_spikes_v1") @@ -275,8 +276,9 @@ def fetch_model(self): """Retrieve the decoding model""" return SortedSpikesDetector.load_model(self.fetch1("classifier_path")) - @staticmethod - def fetch_environments(key): + @classmethod + @full_key_decorator(required_keys=["decoding_param_name"]) + def fetch_environments(cls, key): """Fetch the environments for the decoding model Parameters @@ -314,8 +316,16 @@ def fetch_environments(key): return classifier.environments - @staticmethod - def fetch_position_info(key): + @classmethod + @full_key_decorator( + required_keys=[ + "position_group_name", + "nwb_file_name", + "encoding_interval", + "decoding_interval", + ] + ) + def fetch_position_info(cls, key): """Fetch the position information for the decoding model Parameters @@ -339,8 +349,16 @@ def fetch_position_info(key): return position_info, position_variable_names - @staticmethod - def fetch_linear_position_info(key): + @classmethod + @full_key_decorator( + required_keys=[ + "position_group_name", + "nwb_file_name", + "encoding_interval", + "decoding_interval", + ] + ) + def fetch_linear_position_info(cls, key): """Fetch the position information and project it onto the track graph Parameters @@ -374,9 +392,16 @@ def fetch_linear_position_info(key): axis=1, ).loc[min_time:max_time] - @staticmethod + @classmethod + @full_key_decorator( + required_keys=["encoding_interval", "decoding_interval"] + ) def fetch_spike_data( - key, filter_by_interval=True, time_slice=None, return_unit_ids=False + cls, + key, + filter_by_interval=True, + time_slice=None, + return_unit_ids=False, ) -> Union[list[np.ndarray], Optional[list[dict]]]: """Fetch the spike times for the decoding model diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py index 2f862c4fb..f0bddc8bb 100644 --- a/src/spyglass/spikesorting/analysis/v1/group.py +++ b/src/spyglass/spikesorting/analysis/v1/group.py @@ -8,6 +8,7 @@ from spyglass.common import Session # noqa: F401 from spyglass.settings import test_mode from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput +from spyglass.utils.dj_helper_fn import full_key_decorator from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart from spyglass.utils.spikesorting import firing_rate_from_spike_indicator @@ -127,9 +128,13 @@ def filter_units( include_mask[ind] = True return include_mask - @staticmethod + @classmethod + @full_key_decorator() def fetch_spike_data( - key: dict, time_slice: list[float] = None, return_unit_ids: bool = False + cls, + key: dict, + time_slice: list[float] = None, + return_unit_ids: bool = False, ) -> Union[list[np.ndarray], Optional[list[dict]]]: """fetch spike times for units in the group diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 42cf67ba0..fb6304624 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -3,6 +3,7 @@ import inspect import multiprocessing.pool import os +from functools import wraps from pathlib import Path from typing import Iterable, List, Type, Union from uuid import uuid4 @@ -580,3 +581,42 @@ def str_to_bool(value) -> bool: if not value: return False return str(value).lower() in ("y", "yes", "t", "true", "on", "1") + + +def full_key_decorator(required_keys: list[str] = None): + """Decorator to ensure that the key is fully specified before calling the + method. If the key is not fully specified, the method will attempt to fetch + the complete key from the database. + + Parameters + ---------- + required_keys : list[str], optional + List of keys that must be present in the key. If None, all keys are + required. Default None + """ + + def decorator(method): + @wraps(method) + def wrapper(cls, key=None, *args, **kwargs): + # Ensure key is not None + if key is None: + key = {} + + # Check if required keys are in key, and fetch if not + key_check = ( + cls.primary_key if required_keys is None else required_keys + ) + if not all([k in key for k in key_check]): + if not len(query := cls() & key) == 1: + raise KeyError( + f"Key {key} is neither fully specified nor a unique entry in" + + f"{cls.full_table_name}" + ) + key = query.fetch1("KEY") + + # Call the original method with the modified key + return method(cls, key, *args, **kwargs) + + return wrapper + + return decorator