From 135d808dd2999fcab1f2721818887e7b25805507 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Thu, 30 Nov 2023 08:00:22 -0800 Subject: [PATCH] Mixin class (#692) * WIP #530 * Add note * nwb_table -> _nwb_table, @rly * Update changelog --- CHANGELOG.md | 7 +-- src/spyglass/common/common_behav.py | 23 ++++----- src/spyglass/common/common_dio.py | 9 ++-- src/spyglass/common/common_ephys.py | 29 ++++-------- src/spyglass/common/common_nwbfile.py | 1 - src/spyglass/common/common_position.py | 16 ++----- src/spyglass/common/common_ripple.py | 17 ++----- src/spyglass/common/common_sensors.py | 11 ++--- src/spyglass/decoding/clusterless.py | 27 ++--------- src/spyglass/decoding/sorted_spikes.py | 8 +--- src/spyglass/lfp/analysis/v1/lfp_band.py | 11 ++--- src/spyglass/lfp/v1/lfp.py | 11 ++--- .../position/v1/position_dlc_centroid.py | 18 ++----- .../position/v1/position_dlc_cohort.py | 14 ++---- .../position/v1/position_dlc_orient.py | 9 +--- .../v1/position_dlc_pose_estimation.py | 16 ++----- .../position/v1/position_dlc_position.py | 16 ++----- .../position/v1/position_dlc_selection.py | 11 ++--- .../position/v1/position_trodes_position.py | 9 +--- .../v1/linearization.py | 8 +--- src/spyglass/ripple/v1/ripple.py | 8 +--- .../spikesorting/spikesorting_curation.py | 18 ++----- src/spyglass/utils/dj_mixin.py | 47 +++++++++++++++++++ 23 files changed, 130 insertions(+), 214 deletions(-) create mode 100644 src/spyglass/utils/dj_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4be2a8fb9..437e27493 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,10 @@ ## [0.4.4] (Unreleased) -- Additional documentation. #686 -- Refactor input validation in DLC pipeline. -- Clean up following pre-commit checks. +- Additional documentation. #690 +- Refactor input validation in DLC pipeline. #688 +- Clean up following pre-commit checks. #688 +- Add Mixin class to centralize `fetch_nwb` functionality. #692 - Minor fixes to LinearizedPositionV1 pipeline #695 ## [0.4.3] (November 7, 2023) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 4f35bcd9e..b01d80b04 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -8,7 +8,7 @@ import pandas as pd import pynwb -from ..utils.dj_helper_fn import fetch_nwb +from ..utils.dj_mixin import SpyglassMixin from ..utils.nwb_helper_fn import ( get_all_spatial_series, get_data_interface, @@ -158,7 +158,7 @@ class RawPosition(dj.Imported): -> PositionSource """ - class PosObject(dj.Part): + class PosObject(SpyglassMixin, dj.Part): definition = """ -> master -> PositionSource.SpatialSeries.proj('id') @@ -166,10 +166,7 @@ class PosObject(dj.Part): raw_position_object_id: varchar(40) # id of spatial series in NWB file """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs - ) + _nwb_table = Nwbfile def fetch1_dataframe(self): INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) @@ -254,13 +251,15 @@ def fetch1_dataframe(self): @schema -class StateScriptFile(dj.Imported): +class StateScriptFile(SpyglassMixin, dj.Imported): definition = """ -> TaskEpoch --- file_object_id: varchar(40) # the object id of the file object """ + _nwb_table = Nwbfile + def make(self, key): """Add a new row to the StateScriptFile table.""" nwb_file_name = key["nwb_file_name"] @@ -309,12 +308,9 @@ def make(self, key): else: print("not a statescript file") - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) - @schema -class VideoFile(dj.Imported): +class VideoFile(SpyglassMixin, dj.Imported): """ Notes @@ -333,6 +329,8 @@ class VideoFile(dj.Imported): video_file_object_id: varchar(40) # the object id of the file object """ + _nwb_table = Nwbfile + def make(self, key): self._no_transaction_make(key) @@ -395,9 +393,6 @@ def _no_transaction_make(self, key, verbose=True): + f"epoch {interval_list_name}" ) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) - @classmethod def update_entries(cls, restrict={}): existing_entries = (cls & restrict).fetch("KEY") diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index e8c12c662..e4dbc7f88 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -4,7 +4,7 @@ import pandas as pd import pynwb -from ..utils.dj_helper_fn import fetch_nwb # dj_replace +from ..utils.dj_mixin import SpyglassMixin from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file from .common_ephys import Raw from .common_interval import IntervalList @@ -15,7 +15,7 @@ @schema -class DIOEvents(dj.Imported): +class DIOEvents(SpyglassMixin, dj.Imported): definition = """ -> Session dio_event_name: varchar(80) # the name assigned to this DIO event @@ -24,6 +24,8 @@ class DIOEvents(dj.Imported): -> IntervalList # the list of intervals for this object """ + _nwb_table = Nwbfile + def make(self, key): nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) @@ -48,9 +50,6 @@ def make(self, key): key["dio_object_id"] = event_series.object_id self.insert1(key, skip_duplicates=True) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) - def plot_all_dio_events(self): """Plot all DIO events in the session. diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index eb792ff17..e2d999755 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -7,6 +7,7 @@ import pynwb from ..utils.dj_helper_fn import fetch_nwb # dj_replace +from ..utils.dj_mixin import SpyglassMixin from ..utils.nwb_helper_fn import ( estimate_sampling_rate, get_config, @@ -217,7 +218,7 @@ def create_from_config(cls, nwb_file_name: str): @schema -class Raw(dj.Imported): +class Raw(SpyglassMixin, dj.Imported): definition = """ # Raw voltage timeseries data, ElectricalSeries in NWB. -> Session @@ -229,6 +230,8 @@ class Raw(dj.Imported): description: varchar(2000) """ + _nwb_table = Nwbfile + def make(self, key): nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) @@ -295,12 +298,9 @@ def nwb_object(self, key): ) return nwbf.objects[raw_object_id] - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) - @schema -class SampleCount(dj.Imported): +class SampleCount(SpyglassMixin, dj.Imported): definition = """ # Sample count :s timestamp timeseries -> Session @@ -308,6 +308,8 @@ class SampleCount(dj.Imported): sample_count_object_id: varchar(40) # the NWB object ID for loading this object from the file """ + _nwb_table = Nwbfile + def make(self, key): nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) @@ -323,9 +325,6 @@ def make(self, key): key["sample_count_object_id"] = sample_count.object_id self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) - @schema class LFPSelection(dj.Manual): @@ -376,7 +375,7 @@ def set_lfp_electrodes(self, nwb_file_name, electrode_list): @schema -class LFP(dj.Imported): +class LFP(SpyglassMixin, dj.Imported): definition = """ -> LFPSelection --- @@ -487,11 +486,6 @@ def nwb_object(self, key): ) return lfp_nwbf.objects[nwb_object_id] - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self, *attrs, **kwargs): nwb_lfp = self.fetch_nwb()[0] return pd.DataFrame( @@ -631,7 +625,7 @@ def set_lfp_band_electrodes( @schema -class LFPBand(dj.Computed): +class LFPBand(SpyglassMixin, dj.Computed): definition = """ -> LFPBandSelection --- @@ -829,11 +823,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self, *attrs, **kwargs): filtered_nwb = self.fetch_nwb()[0] return pd.DataFrame( diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index f6ee9c49b..d66e8ad9f 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -2,7 +2,6 @@ import random import stat import string -from pathlib import Path import datajoint as dj import numpy as np diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 67b3ce95d..d0d1b1774 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -24,7 +24,7 @@ ) from ..settings import raw_dir, video_dir -from ..utils.dj_helper_fn import fetch_nwb +from ..utils.dj_mixin import SpyglassMixin from .common_behav import RawPosition, VideoFile from .common_interval import IntervalList # noqa F401 from .common_nwbfile import AnalysisNwbfile @@ -70,7 +70,7 @@ class IntervalPositionInfoSelection(dj.Lookup): @schema -class IntervalPositionInfo(dj.Computed): +class IntervalPositionInfo(SpyglassMixin, dj.Computed): """Computes the smoothed head position, orientation and velocity for a given interval.""" @@ -449,11 +449,6 @@ def calculate_position_info( "speed": speed, } - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self._data_to_df(self.fetch_nwb()[0]) @@ -598,7 +593,7 @@ class IntervalLinearizationSelection(dj.Lookup): @schema -class IntervalLinearizedPosition(dj.Computed): +class IntervalLinearizedPosition(SpyglassMixin, dj.Computed): """Linearized position for a given interval""" definition = """ @@ -674,11 +669,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self.fetch_nwb()[0]["linearized_position"].set_index("time") diff --git a/src/spyglass/common/common_ripple.py b/src/spyglass/common/common_ripple.py index 9bc10f4fc..f2744fdd5 100644 --- a/src/spyglass/common/common_ripple.py +++ b/src/spyglass/common/common_ripple.py @@ -4,13 +4,11 @@ import pandas as pd from ripple_detection import Karlsson_ripple_detector, Kay_ripple_detector from ripple_detection.core import gaussian_smooth, get_envelope -from spyglass.common import ( - IntervalList, # noqa - IntervalPositionInfo, -) -from spyglass.common import LFPBand, LFPBandSelection + +from spyglass.common import IntervalList # noqa +from spyglass.common import IntervalPositionInfo, LFPBand, LFPBandSelection from spyglass.common.common_nwbfile import AnalysisNwbfile -from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("common_ripple") @@ -129,7 +127,7 @@ def insert_default(self): @schema -class RippleTimes(dj.Computed): +class RippleTimes(SpyglassMixin, dj.Computed): definition = """ -> RippleParameters -> RippleLFPSelection @@ -178,11 +176,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): """Convenience function for returning the marks in a readable format""" return self.fetch_dataframe()[0] diff --git a/src/spyglass/common/common_sensors.py b/src/spyglass/common/common_sensors.py index 2507abb82..b90aa5a71 100644 --- a/src/spyglass/common/common_sensors.py +++ b/src/spyglass/common/common_sensors.py @@ -3,18 +3,18 @@ import datajoint as dj import pynwb +from ..utils.dj_mixin import SpyglassMixin +from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file from .common_ephys import Raw from .common_interval import IntervalList # noqa: F401 from .common_nwbfile import Nwbfile from .common_session import Session # noqa: F401 -from ..utils.dj_helper_fn import fetch_nwb -from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file schema = dj.schema("common_sensors") @schema -class SensorData(dj.Imported): +class SensorData(SpyglassMixin, dj.Imported): definition = """ -> Session --- @@ -22,6 +22,8 @@ class SensorData(dj.Imported): -> IntervalList # the list of intervals for this object """ + _nwb_table = Nwbfile + def make(self, key): nwb_file_name = key["nwb_file_name"] nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) @@ -40,6 +42,3 @@ def make(self, key): Raw & {"nwb_file_name": nwb_file_name} ).fetch1("interval_list_name") self.insert1(key) - - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs) diff --git a/src/spyglass/decoding/clusterless.py b/src/spyglass/decoding/clusterless.py index 7cc149eda..e2551872a 100644 --- a/src/spyglass/decoding/clusterless.py +++ b/src/spyglass/decoding/clusterless.py @@ -31,10 +31,7 @@ from replay_trajectory_classification.initial_conditions import ( UniformInitialConditions, ) -from ripple_detection import ( - get_multiunit_population_firing_rate, - multiunit_HSE_detector, -) +from ripple_detection import get_multiunit_population_firing_rate from tqdm.auto import tqdm from spyglass.common.common_behav import ( @@ -61,6 +58,7 @@ SpikeSortingSelection, ) from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("decoding_clusterless") @@ -119,7 +117,7 @@ class UnitMarkParameters(dj.Manual): @schema -class UnitMarks(dj.Computed): +class UnitMarks(SpyglassMixin, dj.Computed): """For each spike time, compute a spike waveform feature associated with that spike. Used for clusterless decoding. """ @@ -225,11 +223,6 @@ def make(self, key): AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): """Convenience function for returning the marks in a readable format""" return self.fetch_dataframe()[0] @@ -325,7 +318,7 @@ class UnitMarksIndicatorSelection(dj.Lookup): @schema -class UnitMarksIndicator(dj.Computed): +class UnitMarksIndicator(SpyglassMixin, dj.Computed): """Bins the spike times and associated spike waveform features into regular time bins according to the sampling rate. Features that fall into the same time bin are averaged. @@ -437,11 +430,6 @@ def plot_all_marks( s=s, ) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self.fetch_dataframe()[0] @@ -583,7 +571,7 @@ def fetch1(self, *args, **kwargs): @schema -class MultiunitFiringRate(dj.Computed): +class MultiunitFiringRate(SpyglassMixin, dj.Computed): """Computes the population multiunit firing rate from the spikes in MarksIndicator.""" @@ -627,11 +615,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self.fetch_dataframe()[0] diff --git a/src/spyglass/decoding/sorted_spikes.py b/src/spyglass/decoding/sorted_spikes.py index 2a9c1cb8c..c20314392 100644 --- a/src/spyglass/decoding/sorted_spikes.py +++ b/src/spyglass/decoding/sorted_spikes.py @@ -40,6 +40,7 @@ restore_classes, ) from spyglass.spikesorting.spikesorting_curation import CuratedSpikeSorting +from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("decoding_sortedspikes") @@ -59,7 +60,7 @@ class SortedSpikesIndicatorSelection(dj.Lookup): @schema -class SortedSpikesIndicator(dj.Computed): +class SortedSpikesIndicator(SpyglassMixin, dj.Computed): """Bins spike times into regular intervals given by the sampling rate. Useful for GLMs and for decoding. @@ -147,11 +148,6 @@ def get_time_bins_from_interval(interval_times, sampling_rate): return np.linspace(start_time, end_time, n_samples) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self.fetch_dataframe()[0] diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py index 4bcc1fa73..1b2596461 100644 --- a/src/spyglass/lfp/analysis/v1/lfp_band.py +++ b/src/spyglass/lfp/analysis/v1/lfp_band.py @@ -13,9 +13,9 @@ interval_list_intersect, ) from spyglass.common.common_nwbfile import AnalysisNwbfile -from spyglass.lfp.lfp_merge import LFPOutput from spyglass.lfp.lfp_electrode import LFPElectrodeGroup -from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.lfp.lfp_merge import LFPOutput +from spyglass.utils.dj_mixin import SpyglassMixin from spyglass.utils.nwb_helper_fn import get_electrode_indices schema = dj.schema("lfp_band_v1") @@ -165,7 +165,7 @@ def set_lfp_band_electrodes( @schema -class LFPBandV1(dj.Computed): +class LFPBandV1(SpyglassMixin, dj.Computed): definition = """ -> LFPBandSelection # the LFP band selection --- @@ -364,11 +364,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self, *attrs, **kwargs): """Fetches the filtered data as a dataframe""" filtered_nwb = self.fetch_nwb()[0] diff --git a/src/spyglass/lfp/v1/lfp.py b/src/spyglass/lfp/v1/lfp.py index 32d168f46..9f7f709a2 100644 --- a/src/spyglass/lfp/v1/lfp.py +++ b/src/spyglass/lfp/v1/lfp.py @@ -14,7 +14,9 @@ from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.common.common_session import Session # noqa: F401 from spyglass.lfp.lfp_electrode import LFPElectrodeGroup -from spyglass.utils.dj_helper_fn import fetch_nwb # dj_replace + +# from spyglass.utils.dj_helper_fn import fetch_nwb # dj_replace +from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("lfp_v1") @@ -43,7 +45,7 @@ class LFPSelection(dj.Manual): @schema -class LFPV1(dj.Computed): +class LFPV1(SpyglassMixin, dj.Computed): """The filtered LFP data""" definition = """ @@ -170,11 +172,6 @@ def make(self, key): orig_key["lfp_object_id"] = lfp_object_id LFPOutput.insert1(orig_key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self, *attrs, **kwargs): nwb_lfp = self.fetch_nwb()[0] return pd.DataFrame( diff --git a/src/spyglass/position/v1/position_dlc_centroid.py b/src/spyglass/position/v1/position_dlc_centroid.py index d1e7e6dba..9f6bcc401 100644 --- a/src/spyglass/position/v1/position_dlc_centroid.py +++ b/src/spyglass/position/v1/position_dlc_centroid.py @@ -8,15 +8,8 @@ from ...common.common_behav import RawPosition from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import ( - _key_to_smooth_func_dict, - get_span_start_stop, - interp_pos, - validate_list, - validate_option, - validate_smooth_params, -) +from ...utils.dj_mixin import SpyglassMixin +from .dlc_utils import _key_to_smooth_func_dict, get_span_start_stop, interp_pos from .position_dlc_cohort import DLCSmoothInterpCohort from .position_dlc_position import DLCSmoothInterpParams @@ -118,7 +111,7 @@ class DLCCentroidSelection(dj.Manual): @schema -class DLCCentroid(dj.Computed): +class DLCCentroid(SpyglassMixin, dj.Computed): """ Table to calculate the centroid of a group of bodyparts """ @@ -321,11 +314,6 @@ def make(self, key): self.insert1(key) logger.logger.info("inserted entry into DLCCentroid") - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] index = pd.Index( diff --git a/src/spyglass/position/v1/position_dlc_cohort.py b/src/spyglass/position/v1/position_dlc_cohort.py index 4faa39bad..acc94eb56 100644 --- a/src/spyglass/position/v1/position_dlc_cohort.py +++ b/src/spyglass/position/v1/position_dlc_cohort.py @@ -1,9 +1,9 @@ +import datajoint as dj import numpy as np import pandas as pd -import datajoint as dj from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_helper_fn import fetch_nwb +from ...utils.dj_mixin import SpyglassMixin from .position_dlc_pose_estimation import DLCPoseEstimation # noqa: F401 from .position_dlc_position import DLCSmoothInterp @@ -38,7 +38,7 @@ class DLCSmoothInterpCohort(dj.Computed): --- """ - class BodyPart(dj.Part): + class BodyPart(SpyglassMixin, dj.Part): definition = """ -> DLCSmoothInterpCohort -> DLCSmoothInterp @@ -48,14 +48,6 @@ class BodyPart(dj.Part): dlc_smooth_interp_info_object_id : varchar(80) """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] index = pd.Index( diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index 09dd862dc..0c47e2ff4 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -6,7 +6,7 @@ from ...common.common_behav import RawPosition from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_helper_fn import fetch_nwb +from ...utils.dj_mixin import SpyglassMixin from .dlc_utils import get_span_start_stop from .position_dlc_cohort import DLCSmoothInterpCohort @@ -70,7 +70,7 @@ class DLCOrientationSelection(dj.Manual): @schema -class DLCOrientation(dj.Computed): +class DLCOrientation(SpyglassMixin, dj.Computed): """ Determines and smooths orientation of a set of bodyparts given a specified method """ @@ -155,11 +155,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] index = pd.Index( diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 899a96a15..f43ad8564 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -9,13 +9,13 @@ import pynwb from IPython.display import display -from ...common.common_behav import ( +from ...common.common_behav import ( # noqa: F401 RawPosition, VideoFile, convert_epoch_interval_name_to_position_interval_name, -) # noqa: F401 +) from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_helper_fn import fetch_nwb +from ...utils.dj_mixin import SpyglassMixin from .dlc_utils import OutputLogger, infer_output_dir from .position_dlc_model import DLCModel @@ -132,7 +132,7 @@ class DLCPoseEstimation(dj.Computed): meters_per_pixel : double # conversion of meters per pixel for analyzed video """ - class BodyPart(dj.Part): + class BodyPart(SpyglassMixin, dj.Part): definition = """ # uses DeepLabCut h5 output for body part position -> DLCPoseEstimation -> DLCModel.BodyPart @@ -142,13 +142,7 @@ class BodyPart(dj.Part): dlc_pose_estimation_likelihood_object_id : varchar(80) """ - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, - (AnalysisNwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) + _nwb_table = AnalysisNwbfile def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 230779929..56cba5978 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -5,13 +5,8 @@ from ...common.common_nwbfile import AnalysisNwbfile from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import ( - _key_to_smooth_func_dict, - get_span_start_stop, - interp_pos, - validate_option, - validate_smooth_params, -) +from ...utils.dj_mixin import SpyglassMixin +from .dlc_utils import _key_to_smooth_func_dict, get_span_start_stop, interp_pos from .position_dlc_pose_estimation import DLCPoseEstimation schema = dj.schema("position_v1_dlc_position") @@ -141,7 +136,7 @@ class DLCSmoothInterpSelection(dj.Manual): @schema -class DLCSmoothInterp(dj.Computed): +class DLCSmoothInterp(SpyglassMixin, dj.Computed): """ Interpolates across low likelihood periods and smooths the position Can take a few minutes. @@ -267,11 +262,6 @@ def make(self, key): self.insert1(key) logger.logger.info("inserted entry into DLCSmoothInterp") - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] index = pd.Index( diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index af4aa15ee..604e23f57 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -11,8 +11,8 @@ convert_epoch_interval_name_to_position_interval_name, ) from ...common.common_nwbfile import AnalysisNwbfile -from ...utils.dj_helper_fn import fetch_nwb -from .dlc_utils import make_video +from ...utils.dj_mixin import SpyglassMixin +from .dlc_utils import get_video_path, make_video from .position_dlc_centroid import DLCCentroid from .position_dlc_cohort import DLCSmoothInterpCohort from .position_dlc_orient import DLCOrientation @@ -39,7 +39,7 @@ class DLCPosSelection(dj.Manual): @schema -class DLCPosV1(dj.Computed): +class DLCPosV1(SpyglassMixin, dj.Computed): """ Combines upstream DLCCentroid and DLCOrientation entries into a single entry with a single Analysis NWB file @@ -139,11 +139,6 @@ def make(self, key): [orig_key], part_name=part_name, skip_duplicates=True ) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): nwb_data = self.fetch_nwb()[0] index = pd.Index( diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 4f0500949..fe43630b9 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -11,7 +11,7 @@ from ...common.common_behav import RawPosition from ...common.common_nwbfile import AnalysisNwbfile from ...common.common_position import IntervalPositionInfo -from ...utils.dj_helper_fn import fetch_nwb +from ...utils.dj_mixin import SpyglassMixin from .dlc_utils import check_videofile, get_video_path schema = dj.schema("position_v1_trodes_position") @@ -143,7 +143,7 @@ def insert_with_default( @schema -class TrodesPosV1(dj.Computed): +class TrodesPosV1(SpyglassMixin, dj.Computed): """ Table to calculate the position based on Trodes tracking """ @@ -212,11 +212,6 @@ def calculate_position_info(*args, **kwargs): """Calculate position info from 2D spatial series.""" return IntervalPositionInfo().calculate_position_info(*args, **kwargs) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return IntervalPositionInfo._data_to_df( self.fetch_nwb()[0], prefix="", add_frame_ind=True diff --git a/src/spyglass/position_linearization/v1/linearization.py b/src/spyglass/position_linearization/v1/linearization.py index 6ca118edd..eef836bd3 100644 --- a/src/spyglass/position_linearization/v1/linearization.py +++ b/src/spyglass/position_linearization/v1/linearization.py @@ -13,6 +13,7 @@ from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.position.position_merge import PositionOutput from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("position_linearization_v1") @@ -103,7 +104,7 @@ class LinearizationSelection(dj.Lookup): @schema -class LinearizedPositionV1(dj.Computed): +class LinearizedPositionV1(SpyglassMixin, dj.Computed): """Linearized position for a given interval""" definition = """ @@ -181,10 +182,5 @@ def make(self, key): [orig_key], part_name=part_name, skip_duplicates=True ) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): return self.fetch_nwb()[0]["linearized_position"].set_index("time") diff --git a/src/spyglass/ripple/v1/ripple.py b/src/spyglass/ripple/v1/ripple.py index 51832de7d..679c6dfed 100644 --- a/src/spyglass/ripple/v1/ripple.py +++ b/src/spyglass/ripple/v1/ripple.py @@ -13,6 +13,7 @@ from spyglass.lfp.analysis.v1.lfp_band import LFPBandSelection, LFPBandV1 from spyglass.position import PositionOutput from spyglass.utils.dj_helper_fn import fetch_nwb +from spyglass.utils.dj_mixin import SpyglassMixin from spyglass.utils.nwb_helper_fn import get_electrode_indices schema = dj.schema("ripple_v1") @@ -144,7 +145,7 @@ def insert_default(self): @schema -class RippleTimesV1(dj.Computed): +class RippleTimesV1(SpyglassMixin, dj.Computed): definition = """ -> RippleLFPSelection -> RippleParameters @@ -196,11 +197,6 @@ def make(self, key): self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - def fetch1_dataframe(self): """Convenience function for returning the marks in a readable format""" return self.fetch_dataframe()[0] diff --git a/src/spyglass/spikesorting/spikesorting_curation.py b/src/spyglass/spikesorting/spikesorting_curation.py index 6b7cfc315..9dc587068 100644 --- a/src/spyglass/spikesorting/spikesorting_curation.py +++ b/src/spyglass/spikesorting/spikesorting_curation.py @@ -15,7 +15,7 @@ from ..common.common_interval import IntervalList from ..common.common_nwbfile import AnalysisNwbfile -from ..utils.dj_helper_fn import fetch_nwb +from ..utils.dj_helper_fn import SpyglassMixin from .merged_sorting_extractor import MergedSortingExtractor from .spikesorting_recording import SortInterval, SpikeSortingRecording from .spikesorting_sorting import SpikeSorting @@ -38,7 +38,7 @@ def apply_merge_groups_to_sorting( @schema -class Curation(dj.Manual): +class Curation(SpyglassMixin, dj.Manual): definition = """ # Stores each spike sorting; similar to IntervalList curation_id: int # a number corresponding to the index of this curation @@ -52,6 +52,8 @@ class Curation(dj.Manual): time_of_creation: int # in Unix time, to the nearest second """ + _nwb_table = AnalysisNwbfile + @staticmethod def insert_curation( sorting_key: dict, @@ -252,11 +254,6 @@ def save_sorting_nwb( return analysis_file_name, units_object_id - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - @schema class WaveformParameters(dj.Manual): @@ -879,7 +876,7 @@ class CuratedSpikeSortingSelection(dj.Manual): @schema -class CuratedSpikeSorting(dj.Computed): +class CuratedSpikeSorting(SpyglassMixin, dj.Computed): definition = """ -> CuratedSpikeSortingSelection --- @@ -998,11 +995,6 @@ def metrics_fields(self): unit_fields.remove("label") return unit_fields - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - @schema class UnitInclusionParameters(dj.Manual): diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py new file mode 100644 index 000000000..3a01456d5 --- /dev/null +++ b/src/spyglass/utils/dj_mixin.py @@ -0,0 +1,47 @@ +from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile +from .dj_helper_fn import fetch_nwb + + +class SpyglassMixin: + """Mixin for Spyglass DataJoint tables .""" + + _nwb_table_dict = { + AnalysisNwbfile: "analysis_file_abs_path", + Nwbfile: "nwb_file_abs_path", + } + + def fetch_nwb(self, *attrs, **kwargs): + """Fetch NWBFile object from relevant table. + + Impleminting class must have a foreign key to Nwbfile or + AnalysisNwbfile or a _nwb_table attribute. + + A class that does not have with either '-> Nwbfile' or + '-> AnalysisNwbfile' in its definition can use a _nwb_table attribute to + specify which table to use. + """ + + if not hasattr(self, "_nwb_table"): + self._nwb_table = ( + AnalysisNwbfile + if "-> AnalysisNwbfile" in self.definition + else Nwbfile + if "-> Nwbfile" in self.definition + else None + ) + + if getattr(self, "_nwb_table", None) is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not have a (Analysis)Nwbfile " + "foreign key or _nwb_table attribute." + ) + + return fetch_nwb( + self, + (self._nwb_table, self._nwb_table_dict[self._nwb_table]), + *attrs, + **kwargs, + ) + + # def delete(self): + # print(f"Deleting with mixin {self.__class__.__name__}...")