Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoding qol updates #1198

Merged
merged 11 commits into from
Dec 5, 2024
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
- Add mixin method `get_fully_defined_key` #1198

### Pipelines

Expand All @@ -57,6 +58,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Decoding

- Fix edge case errors in spike time loading #1083
- Allow fetch of partial key from `DecodingParameters` #1198
- Allow data fetching with partial but unique key #1198

- Linearization

Expand Down Expand Up @@ -115,7 +118,6 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add tool for checking threads for metadata locks on a table #1063
- Use peripheral tables as fallback in `TableChains` #1035
- Ignore non-Spyglass tables during descendant check for `part_masters` #1035
- Add utility `full_key_decorator` for use in mixin tables #1198

### Pipelines

Expand All @@ -140,8 +142,6 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add option to upsample data rate in `PositionGroup` #1008
- Avoid interpolating over large `nan` intervals in position #1033
- Minor code calling corrections #1073
- Allow fetch of partial key from `DecodingParameters` #1198
- Allow data fetching with partial but unique key #1198

- Position

Expand Down
51 changes: 30 additions & 21 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.settings import config
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

schema = dj.schema("decoding_clusterless_v1")
Expand Down Expand Up @@ -317,7 +316,6 @@ def fetch_model(self):
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Expand All @@ -331,6 +329,9 @@ def fetch_environments(cls, key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand All @@ -357,14 +358,6 @@ def fetch_environments(cls, key):
return classifier.environments

@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Expand All @@ -378,6 +371,15 @@ def fetch_position_info(cls, key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -391,14 +393,6 @@ def fetch_position_info(cls, key):
return position_info, position_variable_names

@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Expand All @@ -412,6 +406,16 @@ def fetch_linear_position_info(cls, key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)

environment = ClusterlessDecodingV1.fetch_environments(key)[0]

position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
Expand All @@ -435,9 +439,6 @@ def fetch_linear_position_info(cls, key):
).loc[min_time:max_time]

@classmethod
@full_key_decorator(
required_keys=["nwb_file_name", "waveform_features_group_name"]
)
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the spike times for the decoding model

Expand All @@ -454,6 +455,14 @@ def fetch_spike_data(cls, key, filter_by_interval=True):
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"waveform_features_group_name",
],
)

waveform_keys = (
(
UnitWaveformFeaturesGroup.UnitFeatures
Expand Down
53 changes: 32 additions & 21 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
SpikeSortingOutput,
) # noqa: F401
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.dj_helper_fn import full_key_decorator

schema = dj.schema("decoding_sorted_spikes_v1")

Expand Down Expand Up @@ -277,7 +276,6 @@ def fetch_model(self):
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Expand All @@ -291,6 +289,10 @@ def fetch_environments(cls, key):
List[TrackGraph]
list of track graphs in the trained model
"""
key = cls.get_fully_defined_key(
key, required_fields=["decoding_param_name"]
)

model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
Expand All @@ -317,14 +319,6 @@ def fetch_environments(cls, key):
return classifier.environments

@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Expand All @@ -338,6 +332,16 @@ def fetch_position_info(cls, key):
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
Expand All @@ -350,14 +354,6 @@ def fetch_position_info(cls, key):
return position_info, position_variable_names

@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Expand All @@ -371,6 +367,16 @@ def fetch_linear_position_info(cls, key):
pd.DataFrame
The linearized position information
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
],
)

environment = SortedSpikesDecodingV1.fetch_environments(key)[0]

position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]
Expand All @@ -393,9 +399,6 @@ def fetch_linear_position_info(cls, key):
).loc[min_time:max_time]

@classmethod
@full_key_decorator(
required_keys=["encoding_interval", "decoding_interval"]
)
def fetch_spike_data(
cls,
key,
Expand Down Expand Up @@ -424,6 +427,14 @@ def fetch_spike_data(
list[np.ndarray]
List of spike times for each unit in the model's spike group
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"encoding_interval",
"decoding_interval",
],
)

spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
key, return_unit_ids=True
)
Expand Down
5 changes: 2 additions & 3 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import datajoint as dj
import numpy as np
from ripple_detection import get_multiunit_population_firing_rate

from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

Expand Down Expand Up @@ -129,7 +127,6 @@ def filter_units(
return include_mask

@classmethod
@full_key_decorator()
def fetch_spike_data(
cls,
key: dict,
Expand All @@ -153,6 +150,8 @@ def fetch_spike_data(
list of np.ndarray
list of spike times for each unit in the group
"""
key = cls.get_fully_defined_key(key)

# get merge_ids for SpikeSortingOutput
merge_ids = (
(
Expand Down
40 changes: 0 additions & 40 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inspect
import multiprocessing.pool
import os
from functools import wraps
from pathlib import Path
from typing import Iterable, List, Type, Union
from uuid import uuid4
Expand Down Expand Up @@ -581,42 +580,3 @@ def str_to_bool(value) -> bool:
if not value:
return False
return str(value).lower() in ("y", "yes", "t", "true", "on", "1")


def full_key_decorator(required_keys: list[str] = None):
    """Decorator ensuring a method's key is fully specified before the call.

    If the supplied key is missing any required field, the wrapped method's
    class is queried with the partial key. A unique match is fetched and
    substituted for the partial key; anything other than exactly one match
    raises a KeyError.

    Parameters
    ----------
    required_keys : list[str], optional
        Fields that must be present in the key. If None, all primary-key
        fields of the class are required. Default None

    Returns
    -------
    Callable
        Decorator for classmethod-style methods of signature
        ``method(cls, key, *args, **kwargs)``.
    """

    def decorator(method):
        @wraps(method)
        def wrapper(cls, key=None, *args, **kwargs):
            # Treat a missing key as an empty restriction
            if key is None:
                key = {}

            # Fields that must be present before the method may run
            key_check = (
                cls.primary_key if required_keys is None else required_keys
            )
            if not all(k in key for k in key_check):
                # Partial key: it must restrict the table to exactly one row
                if not len(query := cls() & key) == 1:
                    raise KeyError(
                        f"Key {key} is neither fully specified nor a unique "
                        + f"entry in {cls.full_table_name}"
                    )
                key = query.fetch1("KEY")

            # Call the original method with the (possibly completed) key
            return method(cls, key, *args, **kwargs)

        return wrapper

    return decorator
21 changes: 21 additions & 0 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ def _safe_context(cls):
else nullcontext()
)

@classmethod
def get_fully_defined_key(
    cls, key: dict = None, required_fields: list[str] = None
) -> dict:
    """Return a key guaranteed to contain all required fields.

    If the key already holds every required field it is returned unchanged.
    Otherwise the table is restricted by the partial key; a unique match is
    fetched and returned, and anything other than exactly one match raises.

    Parameters
    ----------
    key : dict, optional
        Full or partial key (or a string restriction). Default None,
        treated as an empty restriction.
    required_fields : list[str], optional
        Fields that must be present in the key. If None, all primary-key
        fields of the table are required. Default None

    Returns
    -------
    dict
        Key containing all required fields.

    Raises
    ------
    KeyError
        If the key is neither fully specified nor a unique entry in the
        table.
    """
    if key is None:
        key = dict()

    # Fall back to the table's primary key when no explicit fields given
    required_fields = required_fields or cls.primary_key
    if isinstance(key, (str, dict)):  # check is either keys or substrings
        # Check if all required fields are already present in the key
        if not all(field in key for field in required_fields):
            # Partial key: must restrict the table to exactly one entry
            if not len(query := cls() & key) == 1:
                raise KeyError(
                    f"Key {key} is neither fully specified nor a unique "
                    + f"entry in {cls.full_table_name}"
                )
            key = query.fetch1("KEY")

    return key

# ------------------------------- fetch_nwb -------------------------------

@cached_property
Expand Down
Loading