
Commit

Merge branch 'master' of https://github.com/LorenFrankLab/spyglass into rel
CBroz1 committed Dec 20, 2024
2 parents 4c52a4c + 36bd132 commit ad54090
Showing 20 changed files with 332 additions and 124 deletions.
CHANGELOG.md: 13 changes (11 additions, 2 deletions)
@@ -31,6 +31,9 @@
 - Remove numpy version restriction #1169
 - Merge table delete removes orphaned master entries #1164
 - Edit `merge_fetch` to expect positional before keyword arguments #1181
+- Allow part restriction `SpyglassMixinPart.delete` #1192
+- Move cleanup of `IntervalList` orphan entries to cron job cleanup process #1195
+- Add mixin method `get_fully_defined_key` #1198

 ### Pipelines

@@ -40,12 +43,17 @@
 - Improve electrodes import efficiency #1125
 - Fix logger method call in `common_task` #1132
 - Export fixes #1164
-- Allow `get_abs_path` to add selection entry.
-- Log restrictions and joins.
+- Allow `get_abs_path` to add selection entry. #1164
+- Log restrictions and joins. #1164
+- Check if querying table inherits mixin in `fetch_nwb`. #1192, #1201
+- Ensure externals entries before adding to export. #1192
+- Error specificity in `LabMemberInfo` #1192

 - Decoding

   - Fix edge case errors in spike time loading #1083
+  - Allow fetch of partial key from `DecodingParameters` #1198
+  - Allow data fetching with partial but unique key #1198

 - Linearization

@@ -62,6 +70,7 @@
   `open-cv` #1168
 - `VideoMaker` class to process frames in multithreaded batches #1168, #1174
 - `TrodesPosVideo` updates for `matplotlib` processor #1174
+- User prompt if ambiguous insert in `DLCModelSource` #1192

 - Spike Sorting
docs/src/ForDevelopers/Management.md: 11 changes (9 additions, 2 deletions)
@@ -228,10 +228,16 @@ disk. There are several tables that retain lists of files that have been
 generated during analyses. If someone deletes analysis entries, files will still
 be on disk.

-To remove orphaned files, we run the following commands in our cron jobs:
+Additionally, there are periphery tables such as `IntervalList` which are used
+to store entries created by downstream tables. These entries are not
+automatically deleted when the downstream entry is removed. To minimize
+interference with ongoing user entry creation, we recommend running these
+cleanups on a less frequent basis (e.g. weekly).
+
+To remove orphaned files and entries, we run the following commands in our
+cron jobs:

 ```python
-from spyglass.common import AnalysisNwbfile
+from spyglass.common import AnalysisNwbfile, IntervalList
 from spyglass.spikesorting import SpikeSorting
 from spyglass.common.common_nwbfile import schema as nwbfile_schema
 from spyglass.decoding.v1.sorted_spikes import schema as spikes_schema
@@ -241,6 +247,7 @@ from spyglass.decoding.v1.clusterless import schema as clusterless_schema
 def main():
     AnalysisNwbfile().nightly_cleanup()
     SpikeSorting().nightly_cleanup()
+    IntervalList().cleanup()
     nwbfile_schema.external['analysis'].delete(delete_external_files=True)
     nwbfile_schema.external['raw'].delete(delete_external_files=True)
     spikes_schema.external['analysis'].delete(delete_external_files=True)
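Since the new `cleanup` method defaults to a dry run (see the `common_interval.py` diff below), a reasonable pattern is to preview the orphans before the cron job deletes them. A minimal sketch, assuming `dry_run=True` only reports what would be removed rather than deleting it:

```python
from spyglass.common import IntervalList

# Preview orphaned IntervalList entries; dry_run=True is the default and is
# assumed here to report orphans without deleting them.
IntervalList().cleanup(dry_run=True)

# In the weekly cron job, perform the actual deletion.
IntervalList().cleanup(dry_run=False)
```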
pyproject.toml: 8 changes (7 additions, 1 deletion)
@@ -131,7 +131,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--quiet-spy", # don't show logging from spyglass
# "--no-dlc", # don't run DLC tests
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
@@ -148,6 +148,12 @@ env = [
"TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
filterwarnings = [
"ignore::ResourceWarning:.*",
"ignore::DeprecationWarning:.*",
"ignore::UserWarning:.*",
"ignore::MissingRequiredBuildWarning:.*",
]

[tool.coverage.run]
source = ["*/src/spyglass/*"]
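The new `filterwarnings` entries use Python's standard warnings-filter syntax (`action:message:category:module`). For the three builtin categories, the effect is roughly equivalent to the standard-library calls below; `MissingRequiredBuildWarning` is omitted since it is a third-party class rather than a builtin:

```python
import warnings

# Approximate stdlib equivalent of the pytest filterwarnings entries above;
# module=".*" mirrors the trailing ":.*" in each entry.
for category in (ResourceWarning, DeprecationWarning, UserWarning):
    warnings.filterwarnings("ignore", category=category, module=".*")
```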
src/spyglass/common/common_interval.py: 2 changes (1 addition, 1 deletion)
@@ -158,7 +158,7 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False):
         if return_fig:
             return fig

-    def nightly_cleanup(self, dry_run=True):
+    def cleanup(self, dry_run=True):
         """Clean up orphaned IntervalList entries."""
         orphans = self - get_child_tables(self)
         if dry_run:
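The `orphans = self - get_child_tables(self)` line relies on DataJoint's antijoin: subtracting a table restricts to rows with no matching entry in the subtracted table. A toy sketch of the operator with hypothetical tables, assuming a configured database connection:

```python
import datajoint as dj

schema = dj.Schema("toy_cleanup")  # hypothetical schema for illustration


@schema
class Session(dj.Manual):
    definition = """
    session_id: int
    """


@schema
class Analysis(dj.Manual):
    definition = """
    -> Session
    """


# Antijoin: Session rows with no referencing Analysis entry, i.e. "orphans".
orphans = Session() - Analysis()
print(orphans)
```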
src/spyglass/common/common_lab.py: 6 changes (4 additions, 2 deletions)
@@ -133,9 +133,11 @@ def get_djuser_name(cls, dj_user) -> str:
         )

         if len(query) != 1:
+            remedy = f"delete {len(query)-1}" if len(query) > 1 else "add one"
             raise ValueError(
-                f"Could not find name for datajoint user {dj_user}"
-                + f" in common.LabMember.LabMemberInfo: {query}"
+                f"Could not find exactly 1 datajoint user {dj_user}"
+                + " in common.LabMember.LabMemberInfo. "
+                + f"Please {remedy}: {query}"
             )

         return query[0]
src/spyglass/common/common_usage.py: 40 changes (21 additions, 19 deletions)
Expand Up @@ -9,12 +9,10 @@
 from typing import List, Union

 import datajoint as dj
-from datajoint import FreeTable
-from datajoint import config as dj_config
 from pynwb import NWBHDF5IO

 from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
-from spyglass.settings import export_dir, test_mode
+from spyglass.settings import test_mode
 from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
 from spyglass.utils.dj_graph import RestrGraph
 from spyglass.utils.dj_helper_fn import (
@@ -174,7 +172,6 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
             Return as a list of dicts: [{'file_path': x}]. Default True.
             If False, returns a list of strings without key.
         """
-        file_table = self * self.File & key
         unique_fp = {
             *[
                 AnalysisNwbfile().get_abs_path(p)
@@ -210,21 +207,26 @@ def _add_externals_to_restr_graph(
         restr_graph : RestrGraph
             The updated RestrGraph
         """
-        raw_tbl = self._externals["raw"]
-        raw_name = raw_tbl.full_table_name
-        raw_restr = (
-            "filepath in ('" + "','".join(self._list_raw_files(key)) + "')"
-        )
-        restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
-
-        analysis_tbl = self._externals["analysis"]
-        analysis_name = analysis_tbl.full_table_name
-        analysis_restr = (  # filepaths have analysis subdir. regexp substrings
-            "filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'"
-        )  # regexp is slow, but we're only doing this once, and future-proof
-        restr_graph.graph.add_node(
-            analysis_name, ft=analysis_tbl, restr=analysis_restr
-        )
-
-        restr_graph.visited.update({raw_name, analysis_name})
+        if raw_files := self._list_raw_files(key):
+            raw_tbl = self._externals["raw"]
+            raw_name = raw_tbl.full_table_name
+            raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
+            restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
+            restr_graph.visited.add(raw_name)
+
+        if analysis_files := self._list_analysis_files(key):
+            analysis_tbl = self._externals["analysis"]
+            analysis_name = analysis_tbl.full_table_name
+            # to avoid issues with analysis subdir, we use REGEXP
+            # this is slow, but we're only doing this once, and future-proof
+            analysis_restr = (
+                "filepath REGEXP '" + "|".join(analysis_files) + "'"
+            )
+            restr_graph.graph.add_node(
+                analysis_name, ft=analysis_tbl, restr=analysis_restr
+            )
+            restr_graph.visited.add(analysis_name)

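The two restriction strings built above behave differently: the raw list uses an exact `IN` match, while the analysis list is joined into a `REGEXP` alternation so that a leading analysis subdirectory in the stored `filepath` still matches. A standalone sketch of the same construction, with hypothetical file names:

```python
# Hypothetical stand-ins for _list_raw_files / _list_analysis_files output.
raw_files = ["subj1_raw.nwb", "subj2_raw.nwb"]
analysis_files = ["subj1_ABC123.nwb"]

# Exact match: filepath in ('subj1_raw.nwb','subj2_raw.nwb')
raw_restr = "filepath in ('" + "','".join(raw_files) + "')"

# Substring match: REGEXP also matches 'analysis/subj1_ABC123.nwb',
# at the cost of a slower query.
analysis_restr = "filepath REGEXP '" + "|".join(analysis_files) + "'"

print(raw_restr)
print(analysis_restr)
```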
src/spyglass/decoding/v1/clusterless.py: 46 changes (38 additions, 8 deletions)
@@ -316,8 +316,8 @@ def fetch_model(self):
"""Retrieve the decoding model"""
return ClusterlessDetector.load_model(self.fetch1("classifier_path"))

@staticmethod
def fetch_environments(key):
@classmethod
def fetch_environments(cls, key):
"""Fetch the environments for the decoding model
Parameters
@@ -330,6 +330,9 @@ def fetch_environments(key):
         List[TrackGraph]
             list of track graphs in the trained model
         """
+        key = cls.get_fully_defined_key(
+            key, required_fields=["decoding_param_name"]
+        )
         model_params = (
             DecodingParameters
             & {"decoding_param_name": key["decoding_param_name"]}
@@ -355,8 +358,8 @@ def fetch_environments(key):

         return classifier.environments

-    @staticmethod
-    def fetch_position_info(key):
+    @classmethod
+    def fetch_position_info(cls, key):
         """Fetch the position information for the decoding model

         Parameters
Tuple[pd.DataFrame, List[str]]
The position information and the names of the position variables
"""
key = cls.get_fully_defined_key(
key,
required_fields=[
"nwb_file_name",
"position_group_name",
"encoding_interval",
"decoding_interval",
],
)
position_group_key = {
"position_group_name": key["position_group_name"],
"nwb_file_name": key["nwb_file_name"],
@@ -381,8 +393,8 @@ def fetch_position_info(key):

         return position_info, position_variable_names

-    @staticmethod
-    def fetch_linear_position_info(key):
+    @classmethod
+    def fetch_linear_position_info(cls, key):
         """Fetch the position information and project it onto the track graph

         Parameters
@@ -395,6 +407,16 @@ def fetch_linear_position_info(key):
         pd.DataFrame
             The linearized position information
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "nwb_file_name",
+                "position_group_name",
+                "encoding_interval",
+                "decoding_interval",
+            ],
+        )
+
         environment = ClusterlessDecodingV1.fetch_environments(key)[0]

         position_df = ClusterlessDecodingV1.fetch_position_info(key)[0]
@@ -417,8 +439,8 @@ def fetch_linear_position_info(key):
             axis=1,
         ).loc[min_time:max_time]

-    @staticmethod
-    def fetch_spike_data(key, filter_by_interval=True):
+    @classmethod
+    def fetch_spike_data(cls, key, filter_by_interval=True):
         """Fetch the spike times for the decoding model

         Parameters
@@ -434,6 +456,14 @@ def fetch_spike_data(key, filter_by_interval=True):
         list[np.ndarray]
             List of spike times for each unit in the model's spike group
         """
+        key = cls.get_fully_defined_key(
+            key,
+            required_fields=[
+                "nwb_file_name",
+                "waveform_features_group_name",
+            ],
+        )
+
         waveform_keys = (
             (
                 UnitWaveformFeaturesGroup.UnitFeatures
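With `get_fully_defined_key`, the fetch helpers above now accept a partial key and resolve the remaining required fields from the table, so long as the restriction matches exactly one entry. A hedged usage sketch; the file and group names are hypothetical, and the return structure follows the docstring above:

```python
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

# Partial key: fields like decoding_interval are assumed to be filled in
# by get_fully_defined_key when exactly one entry matches.
partial_key = {
    "nwb_file_name": "example20241220_.nwb",  # hypothetical file
    "waveform_features_group_name": "all_tetrodes",  # hypothetical group
}
spike_data = ClusterlessDecodingV1.fetch_spike_data(partial_key)
```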
src/spyglass/decoding/v1/core.py: 52 changes (36 additions, 16 deletions)
@@ -70,28 +70,48 @@ def insert(self, rows, *args, **kwargs):
     def fetch(self, *args, **kwargs):
         """Return decoding parameters as a list of classes."""
         rows = super().fetch(*args, **kwargs)
-        if len(rows) > 0 and len(rows[0]) > 1:
+        if kwargs.get("format", None) == "array":
+            # case when recalled by dj.fetch(), class conversion performed later in stack
+            return rows
+
+        if not len(args):
+            # infer args from table heading
+            args = tuple(self.heading)
+
+        if "decoding_params" not in args:
+            return rows
+
+        params_index = args.index("decoding_params")
+        if len(args) == 1:
+            # only fetching decoding_params
+            content = [restore_classes(r) for r in rows]
+        elif len(rows):
             content = []
-            for (
-                decoding_param_name,
-                decoding_params,
-                decoding_kwargs,
-            ) in rows:
-                content.append(
-                    (
-                        decoding_param_name,
-                        restore_classes(decoding_params),
-                        decoding_kwargs,
-                    )
-                )
+            for row in zip(*rows):
+                row = list(row)
+                row[params_index] = restore_classes(row[params_index])
+                content.append(tuple(row))
         else:
             content = rows
         return content

     def fetch1(self, *args, **kwargs):
         """Return one decoding paramset as a class."""
         row = super().fetch1(*args, **kwargs)
-        row["decoding_params"] = restore_classes(row["decoding_params"])
+
+        if len(args) == 0:
+            row["decoding_params"] = restore_classes(row["decoding_params"])
+            return row
+
+        if "decoding_params" in args:
+            if len(args) == 1:
+                return restore_classes(row)
+            row = list(row)
+            row[args.index("decoding_params")] = restore_classes(
+                row[args.index("decoding_params")]
+            )
+            return tuple(row)

         return row


@@ -126,8 +146,8 @@ def create_group(
         }
         if self & group_key:
             logger.error(  # Easier for pytests to not raise error on duplicate
-                f"Group {nwb_file_name}: {group_name} already exists"
-                + "please delete the group before creating a new one"
+                f"Group {nwb_file_name}: {group_name} already exists. "
+                + "Please delete the group before creating a new one"
             )
             return
         self.insert1(
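The reworked `fetch`/`fetch1` restore the stored parameter classes only when `decoding_params` is actually among the requested attributes, and pass DataJoint-internal `format="array"` calls through untouched. A hedged usage sketch; the paramset name is hypothetical:

```python
from spyglass.decoding.v1.core import DecodingParameters

restriction = {"decoding_param_name": "my_paramset"}  # hypothetical name

# fetch1 with no arguments restores the classes inside decoding_params.
row = (DecodingParameters & restriction).fetch1()

# Fetching the single named attribute also returns restored classes.
params = (DecodingParameters & restriction).fetch("decoding_params")
```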