Decoding qol updates #1198

Merged · 11 commits · Dec 5, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -115,6 +115,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add tool for checking threads for metadata locks on a table #1063
- Use peripheral tables as fallback in `TableChains` #1035
- Ignore non-Spyglass tables during descendant check for `part_masters` #1035
- Add utility `full_key_decorator` for use in mixin tables #1198

### Pipelines

@@ -139,6 +140,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add option to upsample data rate in `PositionGroup` #1008
- Avoid interpolating over large `nan` intervals in position #1033
- Minor code calling corrections #1073
- Allow fetch of partial key from `DecodingParameters` #1198
- Allow data fetching with partial but unique key #1198

- Position

37 changes: 29 additions & 8 deletions src/spyglass/decoding/v1/clusterless.py
@@ -32,6 +32,7 @@
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.settings import config
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

schema = dj.schema("decoding_clusterless_v1")
@@ -315,8 +316,9 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Parameters
@@ -354,8 +356,16 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Parameters
@@ -380,8 +390,16 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Parameters
@@ -416,8 +434,11 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
def fetch_spike_data(key, filter_by_interval=True):
@classmethod
@full_key_decorator(
required_keys=["nwb_file_name", "waveform_features_group_name"]
)
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the spike times for the decoding model

Parameters
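A short usage sketch of the new calling convention (the key values below are illustrative and not taken from this PR): the decorated classmethods can now be called with a partial key, as long as it restricts the table to exactly one entry.

from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

# Assumed to match exactly one row in ClusterlessDecodingV1
partial_key = {"nwb_file_name": "example20240101_.nwb"}

# The decorator fills in the remaining required fields before the method runs
position_info, position_variable_names = ClusterlessDecodingV1.fetch_position_info(
    partial_key
)

# A key that already contains all required fields skips the lookup; a key that
# matches zero or several rows raises KeyError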
48 changes: 34 additions & 14 deletions src/spyglass/decoding/v1/core.py
@@ -69,28 +69,48 @@ def insert(self, rows, *args, **kwargs):
def fetch(self, *args, **kwargs):
"""Return decoding parameters as a list of classes."""
rows = super().fetch(*args, **kwargs)
if len(rows) > 0 and len(rows[0]) > 1:
if kwargs.get("format", None) == "array":
# case when recalled by dj.fetch(), class conversion performed later in stack
return rows

if not len(args):
# infer args from table heading
args = tuple(self.heading)

if "decoding_params" not in args:
return rows

params_index = args.index("decoding_params")
if len(args) == 1:
# only fetching decoding_params
content = [restore_classes(r) for r in rows]
elif len(rows):
content = []
for (
decoding_param_name,
decoding_params,
decoding_kwargs,
) in rows:
content.append(
(
decoding_param_name,
restore_classes(decoding_params),
decoding_kwargs,
)
)
for row in zip(*rows):
row = list(row)
row[params_index] = restore_classes(row[params_index])
content.append(tuple(row))
else:
content = rows
return content

def fetch1(self, *args, **kwargs):
"""Return one decoding paramset as a class."""
row = super().fetch1(*args, **kwargs)
row["decoding_params"] = restore_classes(row["decoding_params"])

if len(args) == 0:
row["decoding_params"] = restore_classes(row["decoding_params"])
return row

if "decoding_params" in args:
if len(args) == 1:
return restore_classes(row)
row = list(row)
row[args.index("decoding_params")] = restore_classes(
row[args.index("decoding_params")]
)
return tuple(row)

return row


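The reworked fetch/fetch1 above can be exercised directly; a minimal sketch, assuming a paramset named "contfrag_clusterless" exists (the name is illustrative):

from spyglass.decoding.v1.core import DecodingParameters

restriction = {"decoding_param_name": "contfrag_clusterless"}

# Fetching only the "decoding_params" attribute returns restored class
# instances rather than the stored dict representation
params = (DecodingParameters & restriction).fetch("decoding_params")

# fetch1 with no arguments still returns the row as a dict, with
# "decoding_params" restored to classes
row = (DecodingParameters & restriction).fetch1()

# fetch(format="array") bypasses the conversion, as used internally by dj.fetch()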
41 changes: 33 additions & 8 deletions src/spyglass/decoding/v1/sorted_spikes.py
@@ -33,6 +33,7 @@
SpikeSortingOutput,
) # noqa: F401
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.dj_helper_fn import full_key_decorator

schema = dj.schema("decoding_sorted_spikes_v1")

@@ -275,8 +276,9 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return SortedSpikesDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
@full_key_decorator(required_keys=["decoding_param_name"])
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model

Parameters
@@ -314,8 +316,16 @@ def fetch_environments(key):

return classifier.environments

@staticmethod
def fetch_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_position_info(cls, key):
"""Fetch the position information for the decoding model

Parameters
@@ -339,8 +349,16 @@ def fetch_position_info(key):

return position_info, position_variable_names

@staticmethod
def fetch_linear_position_info(key):
@classmethod
@full_key_decorator(
required_keys=[
"position_group_name",
"nwb_file_name",
"encoding_interval",
"decoding_interval",
]
)
def fetch_linear_position_info(cls, key):
"""Fetch the position information and project it onto the track graph

Parameters
@@ -374,9 +392,16 @@ def fetch_linear_position_info(key):
axis=1,
).loc[min_time:max_time]

@staticmethod
@classmethod
@full_key_decorator(
required_keys=["encoding_interval", "decoding_interval"]
)
def fetch_spike_data(
key, filter_by_interval=True, time_slice=None, return_unit_ids=False
cls,
key,
filter_by_interval=True,
time_slice=None,
return_unit_ids=False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""Fetch the spike times for the decoding model

9 changes: 7 additions & 2 deletions src/spyglass/spikesorting/analysis/v1/group.py
@@ -8,6 +8,7 @@
from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.utils.dj_helper_fn import full_key_decorator
from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator

@@ -127,9 +128,13 @@ def filter_units(
include_mask[ind] = True
return include_mask

@staticmethod
@classmethod
@full_key_decorator()
def fetch_spike_data(
key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
cls,
key: dict,
time_slice: list[float] = None,
return_unit_ids: bool = False,
) -> Union[list[np.ndarray], Optional[list[dict]]]:
"""fetch spike times for units in the group

40 changes: 40 additions & 0 deletions src/spyglass/utils/dj_helper_fn.py
@@ -3,6 +3,7 @@
import inspect
import multiprocessing.pool
import os
from functools import wraps
from pathlib import Path
from typing import Iterable, List, Type, Union
from uuid import uuid4
@@ -580,3 +581,42 @@ def str_to_bool(value) -> bool:
if not value:
return False
return str(value).lower() in ("y", "yes", "t", "true", "on", "1")


def full_key_decorator(required_keys: list[str] = None):
Member (reviewer):

Hi @samuelbray32 - It's not clear to me what motivated using a decorator over adding a helper method to the mixin. A decorator might grant this functionality to a non-table method, or to classes that don't inherit the mixin, but it doesn't seem like that's the case here. It's also not clear to me why the methods you use it on require classmethod. Are there pieces of the funcs above that access the uninstanced class?

As a general rule, I'd like to be able to accept string restrictions anywhere we might accept a key as a dict, but I understand it's more hoops.

class SpyglassMixin:
    ...

    def get_single_entry_key(
        self, key: dict = None, required_fields: list[str] = None
    ) -> dict:
        if key is None:
            return dict()  # I like explicit dict() over {}, bc {} could be set

        required_fields = required_fields or self.primary_key
        if isinstance(key, (str, dict)):  # check is either keys or substrings
            if not all(field in key for field in required_fields):
                raise ValueError(
                    f"Key must contain all required fields: {required_fields}"
                )

        if len(query := self & key) == 1:
            return query.fetch("KEY", as_dict=True)[0]

        raise ValueError(
            f"Key must identify exactly one entry in {self.camel_name}: {key}"
        )


class SortedSpikesDecodingV1(SpyglassMixin, dj.Computed):
    ...

    def fetch_environments(self, key):
        key = self.get_single_entry_key(key, ["decoding_param_name"])

Collaborator (author):

Thanks Chris:

  • I wrote it as a decorator because my original plan was that it would require no additional arguments and simply fetch the full key every time before running the function. However, several of the cases where it would be useful need to work prior to key insertion in the calling table, requiring a check of whether the existing key is sufficient before calling fetch. I agree that the decorator isn't as clean anymore and can move it into a mixin method.

  • With respect to the use of classmethod, I did that for functions that were previously staticmethod; I was trying to make the smallest change to the existing code while still adding the utility. For those methods, I'm assuming they were made static because they are used in the make function before the key is inserted into the table.

  • For string restrictions, several of these functions already assume that the passed key is a dict. This utility method could handle the case of strings by fetching the key from the table, which would generalize the accepted restrictions of the existing functions.

Member (reviewer):

Makes sense. To keep down the complexity of the codebase overall, I think I'd avoid decorators until we find a case where there's no other way around it.

I do wish we had a better pattern for when we do and don't use classmethods. As is, it seems pretty unpredictable for the end user whether to use Table.method(arg) or Table().method(arg). I've tried to default to the latter whenever possible, without classmethod, and to reserve the former for when the method is doing something at the DJ class level. No reason to change an existing method. Is 'before insertion' the convention you've used?

Collaborator (author):

That's the convention I inferred from the code, particularly in cases where the function interacts with upstream tables or is used in the make function, since those are things that could be called from any future new version of the table.

"""Decorator to ensure that the key is fully specified before calling the
method. If the key is not fully specified, the method will attempt to fetch
the complete key from the database.

Parameters
----------
required_keys : list[str], optional
List of keys that must be present in the key. If None, all keys are
required. Default None
"""

def decorator(method):
@wraps(method)
def wrapper(cls, key=None, *args, **kwargs):
# Ensure key is not None
if key is None:
key = {}

# Check if required keys are in key, and fetch if not
key_check = (
cls.primary_key if required_keys is None else required_keys
)
if not all([k in key for k in key_check]):
if not len(query := cls() & key) == 1:
raise KeyError(
f"Key {key} is neither fully specified nor a unique entry in"
+ f"{cls.full_table_name}"
)
key = query.fetch1("KEY")

# Call the original method with the modified key
return method(cls, key, *args, **kwargs)

return wrapper

return decorator
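To make the behavior concrete, here is a self-contained toy sketch of the wrapper logic (FakeQuery, FakeTable, and their rows are invented for illustration; a real caller is a DataJoint table exposing primary_key, full_table_name, restriction via &, and fetch1("KEY")). It assumes the decorator is importable from its new location:

from spyglass.utils.dj_helper_fn import full_key_decorator


class FakeQuery:
    """Stands in for `cls() & key`: keeps the rows matching a partial key."""

    def __init__(self, rows, key):
        self.rows = [
            row
            for row in rows
            if all(row.get(k) == v for k, v in key.items())
        ]

    def __len__(self):
        return len(self.rows)

    def fetch1(self, *_):
        return dict(self.rows[0])


class FakeTable:
    primary_key = ["nwb_file_name", "decoding_param_name"]
    full_table_name = "`decoding_fake`.`fake_table`"
    _rows = [{"nwb_file_name": "f1.nwb", "decoding_param_name": "contfrag"}]

    def __and__(self, key):
        return FakeQuery(self._rows, key)

    @classmethod
    @full_key_decorator(required_keys=["decoding_param_name"])
    def fetch_environments(cls, key):
        # By the time the body runs, `key` has been completed by the wrapper
        return key


# The partial key omits "decoding_param_name"; the wrapper resolves it because
# the restriction matches exactly one row
print(FakeTable.fetch_environments({"nwb_file_name": "f1.nwb"}))
# {'nwb_file_name': 'f1.nwb', 'decoding_param_name': 'contfrag'}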