Skip to content

Commit

Permalink
implement full_key_decorator within clusterless pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Dec 2, 2024
1 parent 19930b7 commit 8ea1a39
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 42 deletions.
61 changes: 29 additions & 32 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import copy
import uuid
from functools import wraps
from pathlib import Path

import datajoint as dj
Expand All @@ -33,31 +32,12 @@
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.settings import config
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

schema = dj.schema("decoding_clusterless_v1")


def classmethod_full_key_decorator(required_keys=[]):
def decorator(method):
@wraps(method)
def wrapper(cls, key=None, *args, **kwargs):
# Ensure key is not None
if key is None:
key = {}

# Check if required keys are in key, and fetch if not
if not all([k in key for k in required_keys]):
key = (cls() & key).fetch1("KEY")

# Call the original method with the modified key
return method(cls, key, *args, **kwargs)

return wrapper

return decorator


@schema
class UnitWaveformFeaturesGroup(SpyglassMixin, dj.Manual):
definition = """
Expand Down Expand Up @@ -336,8 +316,9 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model
Parameters
Expand Down Expand Up @@ -375,8 +356,16 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model
Parameters
Expand All @@ -401,8 +390,16 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph
Parameters
Expand Down Expand Up @@ -437,8 +434,11 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
@classmethod
@full_key_decorator(
required_keys=["nwb_file_name", "waveform_features_group_name"]
)
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the spike times for the decoding model
Parameters
Expand Down Expand Up @@ -488,9 +488,6 @@ def fetch_spike_data(key, filter_by_interval=True):
return new_spike_times, new_waveform_features

@classmethod
@classmethod_full_key_decorator(
required_keys=["nwb_file_name", "waveform_features_group_name"]
)
def get_spike_indicator(cls, key, time):
"""get spike indicator matrix for the group
Expand Down
41 changes: 33 additions & 8 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SpikeSortingOutput,
) # noqa: F401
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.dj_helper_fn import full_key_decorator

schema = dj.schema("decoding_sorted_spikes_v1")

Expand Down Expand Up @@ -275,8 +276,9 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model
Parameters
Expand Down Expand Up @@ -314,8 +316,16 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model
Parameters
Expand All @@ -339,8 +349,16 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph
Parameters
Expand Down Expand Up @@ -374,9 +392,16 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
@classmethod
@full_key_decorator(
required_keys=["encoding_interval", "decoding_interval"]
)
def fetch_spike_data(
key, filter_by_interval=True, time_slice=None, return_unit_ids=False
cls,
key,
filter_by_interval=True,
time_slice=None,
return_unit_ids=False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""Fetch the spike times for the decoding model
Expand Down
9 changes: 7 additions & 2 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

Expand Down Expand Up @@ -127,9 +128,13 @@ def filter_units(
include_mask[ind] = True
return include_mask

@staticmethod
@classmethod
@full_key_decorator()
def fetch_spike_data(
key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
cls,
key: dict,
time_slice: list[float] = None,
return_unit_ids: bool = False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""fetch spike times for units in the group
Expand Down
40 changes: 40 additions & 0 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import multiprocessing.pool
import os
from functools import wraps
from pathlib import Path
from typing import Iterable, List, Type, Union
from uuid import uuid4
Expand Down Expand Up @@ -580,3 +581,42 @@ def str_to_bool(value) -> bool:
if not value:
return False
return str(value).lower() in ("y", "yes", "t", "true", "on", "1")


def full_key_decorator(required_keys: list[str] = None):
"""Decorator to ensure that the key is fully specified before calling the
method. If the key is not fully specified, the method will attempt to fetch
the complete key from the database.
Parameters
----------
required_keys : list[str], optional
List of keys that must be present in the key. If None, all keys are
required. Default None
"""

def decorator(method):
@wraps(method)
def wrapper(cls, key=None, *args, **kwargs):
# Ensure key is not None
if key is None:
key = {}

# Check if required keys are in key, and fetch if not
key_check = (
cls.primary_key if required_keys is None else required_keys
)
if not all([k in key for k in key_check]):
if not len(query := cls() & key) == 1:
raise KeyError(
f"Key {key} is neither fully specified nor a unique entry in"
+ f"{cls.full_table_name}"
)
key = query.fetch1("KEY")

# Call the original method with the modified key
return method(cls, key, *args, **kwargs)

return wrapper

return decorator

0 comments on commit 8ea1a39

Please sign in to comment.