Skip to content

Commit

Permalink
Mixin class (LorenFrankLab#692)
Browse files Browse the repository at this point in the history
* WIP LorenFrankLab#530

* Add note

* nwb_table -> _nwb_table, @rly

* Update changelog
  • Loading branch information
CBroz1 authored Nov 30, 2023
1 parent 4fa761a commit 135d808
Show file tree
Hide file tree
Showing 23 changed files with 130 additions and 214 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

## [0.4.4] (Unreleased)

- Additional documentation. #686
- Refactor input validation in DLC pipeline.
- Clean up following pre-commit checks.
- Additional documentation. #690
- Refactor input validation in DLC pipeline. #688
- Clean up following pre-commit checks. #688
- Add Mixin class to centralize `fetch_nwb` functionality. #692
- Minor fixes to LinearizedPositionV1 pipeline #695

## [0.4.3] (November 7, 2023)
Expand Down
23 changes: 9 additions & 14 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import pynwb

from ..utils.dj_helper_fn import fetch_nwb
from ..utils.dj_mixin import SpyglassMixin
from ..utils.nwb_helper_fn import (
get_all_spatial_series,
get_data_interface,
Expand Down Expand Up @@ -158,18 +158,15 @@ class RawPosition(dj.Imported):
-> PositionSource
"""

class PosObject(dj.Part):
class PosObject(SpyglassMixin, dj.Part):
definition = """
-> master
-> PositionSource.SpatialSeries.proj('id')
---
raw_position_object_id: varchar(40) # id of spatial series in NWB file
"""

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs
)
_nwb_table = Nwbfile

def fetch1_dataframe(self):
INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1)
Expand Down Expand Up @@ -254,13 +251,15 @@ def fetch1_dataframe(self):


@schema
class StateScriptFile(dj.Imported):
class StateScriptFile(SpyglassMixin, dj.Imported):
definition = """
-> TaskEpoch
---
file_object_id: varchar(40) # the object id of the file object
"""

_nwb_table = Nwbfile

def make(self, key):
"""Add a new row to the StateScriptFile table."""
nwb_file_name = key["nwb_file_name"]
Expand Down Expand Up @@ -309,12 +308,9 @@ def make(self, key):
else:
print("not a statescript file")

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)


@schema
class VideoFile(dj.Imported):
class VideoFile(SpyglassMixin, dj.Imported):
"""
Notes
Expand All @@ -333,6 +329,8 @@ class VideoFile(dj.Imported):
video_file_object_id: varchar(40) # the object id of the file object
"""

_nwb_table = Nwbfile

def make(self, key):
self._no_transaction_make(key)

Expand Down Expand Up @@ -395,9 +393,6 @@ def _no_transaction_make(self, key, verbose=True):
+ f"epoch {interval_list_name}"
)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)

@classmethod
def update_entries(cls, restrict={}):
existing_entries = (cls & restrict).fetch("KEY")
Expand Down
9 changes: 4 additions & 5 deletions src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
import pynwb

from ..utils.dj_helper_fn import fetch_nwb # dj_replace
from ..utils.dj_mixin import SpyglassMixin
from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file
from .common_ephys import Raw
from .common_interval import IntervalList
Expand All @@ -15,7 +15,7 @@


@schema
class DIOEvents(dj.Imported):
class DIOEvents(SpyglassMixin, dj.Imported):
definition = """
-> Session
dio_event_name: varchar(80) # the name assigned to this DIO event
Expand All @@ -24,6 +24,8 @@ class DIOEvents(dj.Imported):
-> IntervalList # the list of intervals for this object
"""

_nwb_table = Nwbfile

def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
Expand All @@ -48,9 +50,6 @@ def make(self, key):
key["dio_object_id"] = event_series.object_id
self.insert1(key, skip_duplicates=True)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)

def plot_all_dio_events(self):
"""Plot all DIO events in the session.
Expand Down
29 changes: 9 additions & 20 deletions src/spyglass/common/common_ephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pynwb

from ..utils.dj_helper_fn import fetch_nwb # dj_replace
from ..utils.dj_mixin import SpyglassMixin
from ..utils.nwb_helper_fn import (
estimate_sampling_rate,
get_config,
Expand Down Expand Up @@ -217,7 +218,7 @@ def create_from_config(cls, nwb_file_name: str):


@schema
class Raw(dj.Imported):
class Raw(SpyglassMixin, dj.Imported):
definition = """
# Raw voltage timeseries data, ElectricalSeries in NWB.
-> Session
Expand All @@ -229,6 +230,8 @@ class Raw(dj.Imported):
description: varchar(2000)
"""

_nwb_table = Nwbfile

def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
Expand Down Expand Up @@ -295,19 +298,18 @@ def nwb_object(self, key):
)
return nwbf.objects[raw_object_id]

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)


@schema
class SampleCount(dj.Imported):
class SampleCount(SpyglassMixin, dj.Imported):
definition = """
# Sample count :s timestamp timeseries
-> Session
---
sample_count_object_id: varchar(40) # the NWB object ID for loading this object from the file
"""

_nwb_table = Nwbfile

def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
Expand All @@ -323,9 +325,6 @@ def make(self, key):
key["sample_count_object_id"] = sample_count.object_id
self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)


@schema
class LFPSelection(dj.Manual):
Expand Down Expand Up @@ -376,7 +375,7 @@ def set_lfp_electrodes(self, nwb_file_name, electrode_list):


@schema
class LFP(dj.Imported):
class LFP(SpyglassMixin, dj.Imported):
definition = """
-> LFPSelection
---
Expand Down Expand Up @@ -487,11 +486,6 @@ def nwb_object(self, key):
)
return lfp_nwbf.objects[nwb_object_id]

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self, *attrs, **kwargs):
nwb_lfp = self.fetch_nwb()[0]
return pd.DataFrame(
Expand Down Expand Up @@ -631,7 +625,7 @@ def set_lfp_band_electrodes(


@schema
class LFPBand(dj.Computed):
class LFPBand(SpyglassMixin, dj.Computed):
definition = """
-> LFPBandSelection
---
Expand Down Expand Up @@ -829,11 +823,6 @@ def make(self, key):

self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self, *attrs, **kwargs):
filtered_nwb = self.fetch_nwb()[0]
return pd.DataFrame(
Expand Down
1 change: 0 additions & 1 deletion src/spyglass/common/common_nwbfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import random
import stat
import string
from pathlib import Path

import datajoint as dj
import numpy as np
Expand Down
16 changes: 3 additions & 13 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

from ..settings import raw_dir, video_dir
from ..utils.dj_helper_fn import fetch_nwb
from ..utils.dj_mixin import SpyglassMixin
from .common_behav import RawPosition, VideoFile
from .common_interval import IntervalList # noqa F401
from .common_nwbfile import AnalysisNwbfile
Expand Down Expand Up @@ -70,7 +70,7 @@ class IntervalPositionInfoSelection(dj.Lookup):


@schema
class IntervalPositionInfo(dj.Computed):
class IntervalPositionInfo(SpyglassMixin, dj.Computed):
"""Computes the smoothed head position, orientation and velocity for a given
interval."""

Expand Down Expand Up @@ -449,11 +449,6 @@ def calculate_position_info(
"speed": speed,
}

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self):
return self._data_to_df(self.fetch_nwb()[0])

Expand Down Expand Up @@ -598,7 +593,7 @@ class IntervalLinearizationSelection(dj.Lookup):


@schema
class IntervalLinearizedPosition(dj.Computed):
class IntervalLinearizedPosition(SpyglassMixin, dj.Computed):
"""Linearized position for a given interval"""

definition = """
Expand Down Expand Up @@ -674,11 +669,6 @@ def make(self, key):

self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self):
return self.fetch_nwb()[0]["linearized_position"].set_index("time")

Expand Down
17 changes: 5 additions & 12 deletions src/spyglass/common/common_ripple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
import pandas as pd
from ripple_detection import Karlsson_ripple_detector, Kay_ripple_detector
from ripple_detection.core import gaussian_smooth, get_envelope
from spyglass.common import (
IntervalList, # noqa
IntervalPositionInfo,
)
from spyglass.common import LFPBand, LFPBandSelection

from spyglass.common import IntervalList # noqa
from spyglass.common import IntervalPositionInfo, LFPBand, LFPBandSelection
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.utils.dj_helper_fn import fetch_nwb
from spyglass.utils.dj_mixin import SpyglassMixin

schema = dj.schema("common_ripple")

Expand Down Expand Up @@ -129,7 +127,7 @@ def insert_default(self):


@schema
class RippleTimes(dj.Computed):
class RippleTimes(SpyglassMixin, dj.Computed):
definition = """
-> RippleParameters
-> RippleLFPSelection
Expand Down Expand Up @@ -178,11 +176,6 @@ def make(self, key):

self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(
self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs
)

def fetch1_dataframe(self):
"""Convenience function for returning the marks in a readable format"""
return self.fetch_dataframe()[0]
Expand Down
11 changes: 5 additions & 6 deletions src/spyglass/common/common_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,27 @@
import datajoint as dj
import pynwb

from ..utils.dj_mixin import SpyglassMixin
from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file
from .common_ephys import Raw
from .common_interval import IntervalList # noqa: F401
from .common_nwbfile import Nwbfile
from .common_session import Session # noqa: F401
from ..utils.dj_helper_fn import fetch_nwb
from ..utils.nwb_helper_fn import get_data_interface, get_nwb_file

schema = dj.schema("common_sensors")


@schema
class SensorData(dj.Imported):
class SensorData(SpyglassMixin, dj.Imported):
definition = """
-> Session
---
sensor_data_object_id: varchar(40) # object id of the data in the NWB file
-> IntervalList # the list of intervals for this object
"""

_nwb_table = Nwbfile

def make(self, key):
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
Expand All @@ -40,6 +42,3 @@ def make(self, key):
Raw & {"nwb_file_name": nwb_file_name}
).fetch1("interval_list_name")
self.insert1(key)

def fetch_nwb(self, *attrs, **kwargs):
return fetch_nwb(self, (Nwbfile, "nwb_file_abs_path"), *attrs, **kwargs)
Loading

0 comments on commit 135d808

Please sign in to comment.