From f56aba0de288e1a9bac10b64830fdfbfcaca2de2 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Thu, 5 Dec 2024 14:16:48 -0600 Subject: [PATCH] Misc fixes (#1192) * #1175 * #1185 * #1183 * Fix circular import * #1163 * #1105 * Fix failing tests, close download subprocesses * WIP: fix decode changes spikesort tests * Fix fickle test * Revert typo --- CHANGELOG.md | 9 ++- pyproject.toml | 8 ++- src/spyglass/common/common_lab.py | 6 +- src/spyglass/common/common_usage.py | 40 +++++++------ src/spyglass/decoding/v1/core.py | 4 +- .../position/v1/position_dlc_model.py | 18 ++++-- .../position/v1/position_dlc_project.py | 2 - src/spyglass/spikesorting/v1/curation.py | 13 ++--- src/spyglass/utils/dj_helper_fn.py | 14 ++++- src/spyglass/utils/dj_mixin.py | 11 +++- tests/conftest.py | 45 ++++++++++---- tests/data_downloader.py | 58 +++++++++++++------ tests/decoding/test_clusterless.py | 1 + tests/spikesorting/test_curation.py | 1 - tests/spikesorting/test_merge.py | 16 ++--- 15 files changed, 162 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d2fa2741..9ca5b47e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Remove numpy version restriction #1169 - Merge table delete removes orphaned master entries #1164 - Edit `merge_fetch` to expect positional before keyword arguments #1181 +- Allow part restriction `SpyglassMixinPart.delete` #1192 ### Pipelines @@ -52,8 +53,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Improve electrodes import efficiency #1125 - Fix logger method call in `common_task` #1132 - Export fixes #1164 - - Allow `get_abs_path` to add selection entry. - - Log restrictions and joins. + - Allow `get_abs_path` to add selection entry. #1164 + - Log restrictions and joins. #1164 + - Check if querying table inherits mixin in `fetch_nwb`. #1192 + - Ensure externals entries before adding to export. #1192 + - Error specificity in `LabMemberInfo` #1192 - Decoding @@ -74,6 +78,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() `open-cv` #1168 - `VideoMaker` class to process frames in multithreaded batches #1168, #1174 - `TrodesPosVideo` updates for `matplotlib` processor #1174 + - User prompt if ambiguous insert in `DLCModelSource` #1192 - Spike Sorting diff --git a/pyproject.toml b/pyproject.toml index a5d8d032d..0a1cd627f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,7 @@ addopts = [ # "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests - # "--quiet-spy", # don't show logging from spyglass + "--quiet-spy", # don't show logging from spyglass # "--no-dlc", # don't run DLC tests "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger @@ -148,6 +148,12 @@ env = [ "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings ] +filterwarnings = [ + "ignore::ResourceWarning:.*", + "ignore::DeprecationWarning:.*", + "ignore::UserWarning:.*", + "ignore::MissingRequiredBuildWarning:.*", +] [tool.coverage.run] source = ["*/src/spyglass/*"] diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index 486041abc..72bdafa6b 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -133,9 +133,11 @@ def get_djuser_name(cls, dj_user) -> str: ) if len(query) != 1: + remedy = f"delete {len(query)-1}" if len(query) > 1 else "add one" raise ValueError( - f"Could not find name for datajoint user {dj_user}" - + f" in common.LabMember.LabMemberInfo: {query}" + f"Could not find exactly 1 datajoint user {dj_user}" + + " in common.LabMember.LabMemberInfo. " + + f"Please {remedy}: {query}" ) return query[0] diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index ccbf7c909..58b28c0f6 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -9,12 +9,10 @@ from typing import List, Union import datajoint as dj -from datajoint import FreeTable -from datajoint import config as dj_config from pynwb import NWBHDF5IO from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile -from spyglass.settings import export_dir, test_mode +from spyglass.settings import test_mode from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger from spyglass.utils.dj_graph import RestrGraph from spyglass.utils.dj_helper_fn import ( @@ -174,7 +172,6 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]: Return as a list of dicts: [{'file_path': x}]. Default True. If False, returns a list of strings without key. """ - file_table = self * self.File & key unique_fp = { *[ AnalysisNwbfile().get_abs_path(p) @@ -210,21 +207,26 @@ def _add_externals_to_restr_graph( restr_graph : RestrGraph The updated RestrGraph """ - raw_tbl = self._externals["raw"] - raw_name = raw_tbl.full_table_name - raw_restr = ( - "filepath in ('" + "','".join(self._list_raw_files(key)) + "')" - ) - restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr) - - analysis_tbl = self._externals["analysis"] - analysis_name = analysis_tbl.full_table_name - analysis_restr = ( # filepaths have analysis subdir. regexp substrings - "filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'" - ) # regexp is slow, but we're only doing this once, and future-proof - restr_graph.graph.add_node( - analysis_name, ft=analysis_tbl, restr=analysis_restr - ) + + if raw_files := self._list_raw_files(key): + raw_tbl = self._externals["raw"] + raw_name = raw_tbl.full_table_name + raw_restr = "filepath in ('" + "','".join(raw_files) + "')" + restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr) + restr_graph.visited.add(raw_name) + + if analysis_files := self._list_analysis_files(key): + analysis_tbl = self._externals["analysis"] + analysis_name = analysis_tbl.full_table_name + # to avoid issues with analysis subdir, we use REGEXP + # this is slow, but we're only doing this once, and future-proof + analysis_restr = ( + "filepath REGEXP '" + "|".join(analysis_files) + "'" + ) + restr_graph.graph.add_node( + analysis_name, ft=analysis_tbl, restr=analysis_restr + ) + restr_graph.visited.add(analysis_name) restr_graph.visited.update({raw_name, analysis_name}) diff --git a/src/spyglass/decoding/v1/core.py b/src/spyglass/decoding/v1/core.py index d58af1643..177a87d22 100644 --- a/src/spyglass/decoding/v1/core.py +++ b/src/spyglass/decoding/v1/core.py @@ -126,8 +126,8 @@ def create_group( } if self & group_key: logger.error( # Easier for pytests to not raise error on duplicate - f"Group {nwb_file_name}: {group_name} already exists" - + "please delete the group before creating a new one" + f"Group {nwb_file_name}: {group_name} already exists. " + + "Please delete the group before creating a new one" ) return self.insert1( diff --git a/src/spyglass/position/v1/position_dlc_model.py b/src/spyglass/position/v1/position_dlc_model.py index 2a64c05c4..78325b306 100644 --- a/src/spyglass/position/v1/position_dlc_model.py +++ b/src/spyglass/position/v1/position_dlc_model.py @@ -98,16 +98,24 @@ def insert_entry( dj.conn(), full_table_name=part_table.parents()[-1] ) & {"project_name": project_name} - if cls._test_mode: # temporary fix for #1105 - project_path = table_query.fetch(limit=1)[0] - else: - project_path = table_query.fetch1("project_path") + n_found = len(table_query) + if n_found != 1: + logger.warning( + f"Found {len(table_query)} entries found for project " + + f"{project_name}:\n{table_query}" + ) + + choice = "y" + if n_found > 1 and not cls._test_mode: + choice = dj.utils.user_choice("Use first entry?")[0] + if n_found == 0 or choice != "y": + return part_table.insert1( { "dlc_model_name": dlc_model_name, "project_name": project_name, - "project_path": project_path, + "project_path": table_query.fetch("project_path", limit=1)[0], **key, }, **kwargs, diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 2f19b1664..86617a526 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -57,8 +57,6 @@ class DLCProject(SpyglassMixin, dj.Manual): With ability to edit config, extract frames, label frames """ - # Add more parameters as secondary keys... - # TODO: collapse params into blob dict definition = """ project_name : varchar(100) # name of DLC project --- diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index 00b1ef81e..593d2c1de 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -9,7 +9,6 @@ import spikeinterface.extractors as se from spyglass.common import BrainRegion, Electrode -from spyglass.common.common_ephys import Raw from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.spikesorting.v1.recording import ( SortGroup, @@ -17,7 +16,7 @@ SpikeSortingRecordingSelection, ) from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection -from spyglass.utils.dj_mixin import SpyglassMixin +from spyglass.utils import SpyglassMixin, logger schema = dj.schema("spikesorting_v1_curation") @@ -84,13 +83,13 @@ def insert_curation( sort_query = cls & {"sorting_id": sorting_id} parent_curation_id = max(parent_curation_id, -1) - if parent_curation_id == -1: + + parent_query = sort_query & {"curation_id": parent_curation_id} + if parent_curation_id == -1 and len(parent_query): # check to see if this sorting with a parent of -1 # has already been inserted and if so, warn the user - query = sort_query & {"parent_curation_id": -1} - if query: - Warning("Sorting has already been inserted.") - return query.fetch("KEY") + logger.warning("Sorting has already been inserted.") + return parent_query.fetch("KEY") # generate curation ID existing_curation_ids = sort_query.fetch("curation_id") diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 42cf67ba0..de07de85b 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -223,6 +223,7 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs): Function to get the absolute path to the NWB file. """ from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile + from spyglass.utils.dj_mixin import SpyglassMixin kwargs["as_dict"] = True # force return as dictionary attrs = attrs or query_expression.heading.names # if none, all @@ -234,11 +235,18 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs): } file_name_str, file_path_fn = tbl_map[which] + # logging arg only if instanced table inherits Mixin + inst = ( # instancing may not be necessary + query_expression() + if isinstance(query_expression, type) + and issubclass(query_expression, dj.Table) + else query_expression + ) + arg = dict(log_export=False) if isinstance(inst, SpyglassMixin) else dict() + # TODO: check that the query_expression restricts tbl - CBroz nwb_files = ( - query_expression.join( - tbl.proj(nwb2load_filepath=attr_name), log_export=False - ) + query_expression.join(tbl.proj(nwb2load_filepath=attr_name), **arg) ).fetch(file_name_str) # Disabled #1024 diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 72e34c04f..91cf35870 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -145,10 +145,10 @@ def _nwb_table_tuple(self) -> tuple: Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. Implemented as a cached_property to avoid circular imports.""" - from spyglass.common.common_nwbfile import ( + from spyglass.common.common_nwbfile import ( # noqa F401 AnalysisNwbfile, Nwbfile, - ) # noqa F401 + ) table_dict = { AnalysisNwbfile: "analysis_file_abs_path", @@ -857,4 +857,9 @@ def delete(self, *args, **kwargs): """Delete master and part entries.""" restriction = self.restriction or True # for (tbl & restr).delete() - (self.master & restriction).delete(*args, **kwargs) + try: # try restriction on master + restricted = self.master & restriction + except DataJointError: # if error, assume restr of self + restricted = self & restriction + + restricted.delete(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 22ffa5d2b..1cb54cbd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,19 +17,13 @@ import pynwb import pytest from datajoint.logging import logger as dj_logger +from hdmf.build.warnings import MissingRequiredBuildWarning from numba import NumbaWarning from pandas.errors import PerformanceWarning from .container import DockerMySQLManager from .data_downloader import DataDownloader -warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") -warnings.filterwarnings("ignore", module="tensorflow") -warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") -warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas") -warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") -warnings.filterwarnings("ignore", category=ResourceWarning, module="datajoint") - # ------------------------------- TESTS CONFIG ------------------------------- # globals in pytest_configure: @@ -114,6 +108,19 @@ def pytest_configure(config): download_dlc=not NO_DLC, ) + warnings.filterwarnings("ignore", module="tensorflow") + warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") + warnings.filterwarnings( + "ignore", category=MissingRequiredBuildWarning, module="hdmf" + ) + warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") + warnings.filterwarnings( + "ignore", category=PerformanceWarning, module="pandas" + ) + warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") + warnings.simplefilter("ignore", category=ResourceWarning) + warnings.simplefilter("ignore", category=DeprecationWarning) + def pytest_unconfigure(config): from spyglass.utils.nwb_helper_fn import close_nwb_files @@ -121,6 +128,9 @@ def pytest_unconfigure(config): close_nwb_files() if TEARDOWN: SERVER.stop() + analysis_dir = BASE_DIR / "analysis" + for file in analysis_dir.glob("*.nwb"): + file.unlink() # ---------------------------- FIXTURES, TEST ENV ---------------------------- @@ -1357,6 +1367,8 @@ def sorter_dict(): @pytest.fixture(scope="session") def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict): + pre = spike_v1.SpikeSorting().fetch("KEY", as_dict=True) + key = { **mini_dict, **sorter_dict, @@ -1367,7 +1379,9 @@ def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict): spike_v1.SpikeSortingSelection.insert_selection(key) spike_v1.SpikeSorting.populate() - yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0] + yield (spike_v1.SpikeSorting() - pre).fetch( + "KEY", as_dict=True, order_by="time_of_sort desc" + )[0] @pytest.fixture(scope="session") @@ -1379,9 +1393,16 @@ def sorting_objs(spike_v1, pop_sort): @pytest.fixture(scope="session") def pop_curation(spike_v1, pop_sort): + + parent_curation_id = -1 + has_sort = spike_v1.CurationV1 & {"sorting_id": pop_sort["sorting_id"]} + if has_sort: + parent_curation_id = has_sort.fetch1("curation_id") + spike_v1.CurationV1.insert_curation( sorting_id=pop_sort["sorting_id"], description="testing sort", + parent_curation_id=parent_curation_id, ) yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch( @@ -1418,20 +1439,20 @@ def metric_objs(spike_v1, pop_metric): @pytest.fixture(scope="session") def pop_curation_metric(spike_v1, pop_metric, metric_objs): labels, merge_groups, metrics = metric_objs - parent_dict = {"parent_curation_id": 0} + desc_dict = dict(description="after metric curation") spike_v1.CurationV1.insert_curation( sorting_id=( spike_v1.MetricCurationSelection & {"metric_curation_id": pop_metric["metric_curation_id"]} ).fetch1("sorting_id"), - **parent_dict, + parent_curation_id=0, labels=labels, merge_groups=merge_groups, metrics=metrics, - description="after metric curation", + **desc_dict, ) - yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0] + yield (spike_v1.CurationV1 & desc_dict).fetch("KEY", as_dict=True)[0] @pytest.fixture(scope="session") diff --git a/tests/data_downloader.py b/tests/data_downloader.py index 40af1ea88..99f1c3ee4 100644 --- a/tests/data_downloader.py +++ b/tests/data_downloader.py @@ -85,29 +85,53 @@ def file_downloads(self) -> Dict[str, Union[Popen, None]]: if dest.exists(): cmd = ["echo", f"Already have {target}"] + ret[target] = "Done" else: cmd = ["curl", "-L", "--output", str(dest), f"{path['url']}"] - - print(f"cmd: {cmd}") - - ret[target] = Popen(cmd, **self.cmd_kwargs) + ret[target] = Popen(cmd, **self.cmd_kwargs) return ret - def wait_for(self, target: str): - """Wait for target to finish downloading.""" - status = self.file_downloads.get(target).poll() - - limit = 10 - while status is None and limit > 0: - time_sleep(5) - limit -= 1 - status = self.file_downloads.get(target).poll() + def wait_for(self, target: str, timeout: int = 50, interval=5): + """Wait for target to finish downloading, and clean up if needed. + + Parameters + ---------- + target : str + Name of file to wait for. + timeout : int, optional + Maximum time to wait for download to finish. + interval : int, optional + Time between checks for download completion. + + Raises + ------ + ValueError + If download failed or target not being downloaded. + TimeoutError + If download took too long. + """ + process = self.file_downloads.get(target) + if not process: + raise ValueError(f"No active download process for target: {target}") + if process == "Done": + return - if status != 0: # Error downloading - raise ValueError(f"Error downloading: {target}") - if limit < 1: # Reached attempt limit - raise TimeoutError(f"Timeout downloading: {target}") + elapsed_time = 0 + try: # Refactored to clean up process streams + while (status := process.poll()) is None: + if elapsed_time >= timeout: + process.terminate() # Terminate on timeout + process.wait() + raise TimeoutError(f"Timeout waiting for {target}.") + time_sleep(interval) + elapsed_time += interval + if status != 0: + raise ValueError(f"Error occurred during download of {target}.") + finally: # Ensure process streams are closed and cleaned up + process.stdout and process.stdout.close() + process.stderr and process.stderr.close() + self.file_downloads[target] = "Done" # Remove target from dict def move_dlc_items(self, dest_dir: Path): """Move completed DLC files to dest_dir.""" diff --git a/tests/decoding/test_clusterless.py b/tests/decoding/test_clusterless.py index fc8967454..66d80c8d0 100644 --- a/tests/decoding/test_clusterless.py +++ b/tests/decoding/test_clusterless.py @@ -31,6 +31,7 @@ def test_fetch_linearized_position(clusterless_pop, clusterless_key): assert lin_pos is not None, "Linearized position is None" +# NOTE: Impacts spikesorting merge tests def test_fetch_spike_by_interval(decode_v1, clusterless_pop, clusterless_key): begin, end = decode_v1.clusterless._get_interval_range(clusterless_key) spikes = clusterless_pop.fetch_spike_data( diff --git a/tests/spikesorting/test_curation.py b/tests/spikesorting/test_curation.py index eac00ab0e..dccff0f69 100644 --- a/tests/spikesorting/test_curation.py +++ b/tests/spikesorting/test_curation.py @@ -80,7 +80,6 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric): expected = { "bad_channel": "False", "contacts": "", - "curation_id": 1, "description": "after metric curation", "electrode_group_name": "0", "electrode_id": 0, diff --git a/tests/spikesorting/test_merge.py b/tests/spikesorting/test_merge.py index ad7ba2510..638179271 100644 --- a/tests/spikesorting/test_merge.py +++ b/tests/spikesorting/test_merge.py @@ -68,22 +68,22 @@ def test_merge_get_sort_group_info(spike_merge, pop_spike_merge): @pytest.fixture(scope="session") def merge_times(spike_merge, pop_spike_merge): - yield spike_merge.get_spike_times(pop_spike_merge) + yield spike_merge.get_spike_times(pop_spike_merge)[0] + + +def assert_shape(df, expected: tuple, msg: str = None): + assert df.shape == expected, f"Unexpected shape: {msg}" def test_merge_get_spike_times(merge_times): - assert ( - merge_times[0].shape[0] == 23908 - ), "SpikeSortingOutput.get_spike_times unexpected shape" + assert_shape(merge_times, (243,), "SpikeSortingOutput.get_spike_times") -@pytest.mark.skip(reason="Not testing bc #1077") def test_merge_get_spike_indicators(spike_merge, pop_spike_merge, merge_times): ret = spike_merge.get_spike_indicator(pop_spike_merge, time=merge_times) - raise NotImplementedError(ret) + assert_shape(ret, (243, 3), "SpikeSortingOutput.get_spike_indicator") -@pytest.mark.skip(reason="Not testing bc #1077") def test_merge_get_firing_rate(spike_merge, pop_spike_merge, merge_times): ret = spike_merge.get_firing_rate(pop_spike_merge, time=merge_times) - raise NotImplementedError(ret) + assert_shape(ret, (243, 3), "SpikeSortingOutput.get_firing_rate")