From 8800f993a8c967b69aeaab0a0c333e3c31e139d7 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Fri, 22 Dec 2023 11:43:06 -0600
Subject: [PATCH 01/16] WIP: Pull from old stash, resolve conflicts

---
 tests/README.md                               |  47 +++++
 ...val.py => test_common_interval_helpers.py} |   0
 tests/conftest.py                             | 173 +++++++++++++++---
 tests/data_import/test_insert_sessions.py     |  26 ++-
 tests/test_insert_beans.py                    |  19 +-
 tests/test_nwb_helper_fn.py                   |   6 +-
 6 files changed, 222 insertions(+), 49 deletions(-)
 create mode 100644 tests/README.md
 rename tests/common/{test_common_interval.py => test_common_interval_helpers.py} (100%)

diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 000000000..476dbb4c8
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,47 @@
+# PyTests
+
+This directory contains files for testing the code. Simply by running
+`pytest` from the root directory, all tests will be run with default parameters
+specified in `pyproject.toml`. Notable optional parameters include...
+
+- Coverage items. The coverage report indicates what percentage of the code was
+  included in tests.
+
+  - `--cov=spyglass`: Which package should be described in the coverage report
+  - `--cov-report term-missing`: Include lines of items missing in coverage
+
+- Verbosity.
+
+  - `-v`: List individual tests, report pass/fail
+  - `--quiet-spy`: Default False. When True, print and other logging statements
+    from Spyglass are silenced.
+
+- Data and database.
+
+  - `--no-server`: Default False, launch Docker container from python. When
+    True, no server is started and tests attempt to connect to existing
+    container.
+  - `--no-teardown`: Default False. When True, docker database tables are
+    preserved on exit. Set to True to inspect output items after testing.
+  - `--datadir ./rel-path/`: Default `./tests/test_data/`. Where to store
+    created files.
+
+- Incremental running.
+
+  - `-m`: Run tests with the
+    [given marker](https://docs.pytest.org/en/6.2.x/usage.html#specifying-tests-selecting-tests)
+    (e.g., `pytest -m current`).
+  - `--sw`: Stepwise. Continue from previously failed test when starting again.
+  - `-s`: No capture. By including `from IPython import embed; embed()` in a
+    test, and using this flag, you can open an IPython environment from within
+    a test.
+  - `--pdb`: Enter debug mode if a test fails.
+  - `tests/test_file.py -k test_name`: To run just a set of tests, specify the
+    file name at the end of the command. To run a single test, further specify
+    `-k` with the test name.
+
+When customizing parameters, comment out the `addopts` line in `pyproject.toml`.
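+For example, the following combines several of these options to run one set of
+tests quietly, preserving the database for inspection afterward: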
+
+```console
+pytest -m current --quiet-spy --no-teardown tests/test_file.py -k test_name
+```
diff --git a/tests/common/test_common_interval.py b/tests/common/test_common_interval_helpers.py
similarity index 100%
rename from tests/common/test_common_interval.py
rename to tests/common/test_common_interval_helpers.py
diff --git a/tests/conftest.py b/tests/conftest.py
index ac1539abf..ca80b5749 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,13 @@
-# directory-specific hook implementations
 import os
+import pathlib
 import shutil
 import sys
 import tempfile
+from contextlib import nullcontext
 
 import datajoint as dj
+import pytest
+from datajoint.logging import logger
 
 from .datajoint._config import DATAJOINT_SERVER_PORT
 from .datajoint._datajoint_server import (
@@ -12,39 +15,81 @@
     run_datajoint_server,
 )
 
+# ---------------------- CONSTANTS ---------------------
+
 thisdir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(thisdir)
-
-
 global __PROCESS
 __PROCESS = None
 
 
 def pytest_addoption(parser):
+    """Permit constants when calling pytest at command line
+
+    Example
+    -------
+    > pytest --quiet-spy
+
+    Parameters
+    ----------
+    --quiet-spy (bool): Default False. Silence print statements from Spyglass.
+    --no-teardown (bool): Default False. Preserve pipeline tables on close.
+    --no-server (bool): Default False. Do not run datajoint server in Docker.
+    --datadir (str): Default './tests/test_data/'. Dir for local input file.
+        WARNING: not yet implemented.
+    """
+    parser.addoption(
+        "--quiet-spy",
+        action="store_true",
+        dest="quiet_spy",
+        default=False,
+        help="Quiet print statements from Spyglass.",
+    )
+    parser.addoption(
+        "--no-server",
+        action="store_true",
+        dest="no_server",
+        default=False,
+        help="Do not launch datajoint server in Docker.",
+    )
     parser.addoption(
-        "--current",
+        "--no-teardown",
         action="store_true",
-        dest="current",
         default=False,
-        help="run only tests marked as current",
+        dest="no_teardown",
+        help="Do not tear down tables after tests.",
+    )
+    parser.addoption(
+        "--datadir",
+        action="store",
+        default="./tests/test_data/",
+        dest="datadir",
+        help="Directory for local input file.",
    )
 
 
-def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "current: for convenience -- mark one test as current"
-    )
+# ------------------- FIXTURES -------------------
+
+
+@pytest.fixture(scope="session")
+def verbose_context(config):
+    """Verbosity context for suppressing Spyglass print statements."""
+    return QuietStdOut() if config.option.quiet_spy else nullcontext()
 
-    markexpr_list = []
 
+@pytest.fixture(scope="session")
+def teardown(config):
+    return not config.option.no_teardown
 
-    if config.option.current:
-        markexpr_list.append("current")
 
-    if len(markexpr_list) > 0:
-        markexpr = " and ".join(markexpr_list)
-        setattr(config.option, "markexpr", markexpr)
+@pytest.fixture(scope="session")
+def spy_config(config):
+    pass
 
-    _set_env()
+
+def pytest_configure(config):
+    """Run on build, after parsing command line options."""
+    _set_env(base_dir=config.option.datadir)
 
     # note that in this configuration, every test will use the same datajoint
     # server. this may create conflicts and dependencies between tests. it may be
    # better but significantly slower to start a new server for every test, but
    # the server needs to be started before tests are collected because
    # datajoint runs when the source files are loaded, not when the tests are
    # run. one solution might be to restart the server after every test
-    global __PROCESS
-    __PROCESS = run_datajoint_server()
+    if not config.option.no_server:
+        global __PROCESS
+        __PROCESS = run_datajoint_server()
 
 
 def pytest_unconfigure(config):
+    """Called before test process is exited."""
     if __PROCESS:
-        print("Terminating datajoint compute resource process")
+        logger.info("Terminating datajoint compute resource process")
         __PROCESS.terminate()
-        # TODO handle ResourceWarning: subprocess X is still running
-        # __PROCESS.join()
-    kill_datajoint_server()
-    shutil.rmtree(os.environ["SPYGLASS_BASE_DIR"])
+        # TODO handle ResourceWarning: subprocess X is still running __PROCESS.join()
+
+    if not config.option.no_server:
+        kill_datajoint_server()
+        shutil.rmtree(os.environ["SPYGLASS_BASE_DIR"])
 
 
-def _set_env():
+# ------------------ GENERAL FUNCTION ------------------
+
+
+def _set_env(base_dir):
     """Set environment variables."""
-    print("Setting datajoint and kachery environment variables.")
-    os.environ["SPYGLASS_BASE_DIR"] = str(tempfile.mkdtemp())
+
+    # TODO: change from tempdir to user supplied dir
+    # spyglass_base_dir = pathlib.Path(base_dir)
+    spyglass_base_dir = pathlib.Path(tempfile.mkdtemp())
+
+    spike_sorting_storage_dir = spyglass_base_dir / "spikesorting"
+    tmp_dir = spyglass_base_dir / "tmp"
+
+    logger.info("Setting datajoint and kachery environment variables.")
+    logger.info(f"SPYGLASS_BASE_DIR set to {spyglass_base_dir}")
+
+    # TODO: make this a fixture
+    # spy_config_dict = dict(
+    #     SPYGLASS_BASE_DIR=str(spyglass_base_dir),
+    #     SPYGLASS_RECORDING_DIR=str(spyglass_base_dir / "recording"),
+    #     SPYGLASS_SORTING_DIR=str(spyglass_base_dir / "sorting"),
+    #     SPYGLASS_WAVEFORMS_DIR=str(spyglass_base_dir / "waveforms"),
+    #     SPYGLASS_TEMP_DIR=str(tmp_dir),
+    #     SPIKE_SORTING_STORAGE_DIR=str(spike_sorting_storage_dir),
+    #     KACHERY_ZONE="franklab.collaborators",
+    #     KACHERY_CLOUD_DIR="/stelmo/nwb/.kachery_cloud",
+    #     KACHERY_STORAGE_DIR=str(spyglass_base_dir / "kachery_storage"),
+    #     KACHERY_TEMP_DIR=str(spyglass_base_dir / "tmp"),
+    #     FIGURL_CHANNEL="franklab2",
+    #     DJ_SUPPORT_FILEPATH_MANAGEMENT="TRUE",
+    #     KACHERY_CLOUD_EPHEMERAL="TRUE",
+    # )
+
+    os.environ["SPYGLASS_BASE_DIR"] = str(spyglass_base_dir)
+    os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE"
+    os.environ["SPIKE_SORTING_STORAGE_DIR"] = str(spike_sorting_storage_dir)
+    os.environ["SPYGLASS_TEMP_DIR"] = str(tmp_dir)
+    os.environ["KACHERY_CLOUD_EPHEMERAL"] = "TRUE"
+
+    os.mkdir(spike_sorting_storage_dir)
+    os.mkdir(tmp_dir)
+
+    raw_dir = spyglass_base_dir / "raw"
+    analysis_dir = spyglass_base_dir / "analysis"
+
+    os.mkdir(raw_dir)
+    os.mkdir(analysis_dir)
 
     dj.config["database.host"] = "localhost"
     dj.config["database.port"] = DATAJOINT_SERVER_PORT
     dj.config["database.user"] = "root"
     dj.config["database.password"] = "tutorial"
+
+    dj.config["stores"] = {
+        "raw": {
+            "protocol": "file",
+            "location": str(raw_dir),
+            "stage": str(raw_dir),
+        },
+        "analysis": {
+            "protocol": "file",
+            "location": str(analysis_dir),
+            "stage": str(analysis_dir),
+        },
+    }
+
+
+class QuietStdOut:
+    """If quiet_spy, used to quiet prints, teardowns and table.delete prints"""
+
+    def __enter__(self):
+        # os.environ["LOG_LEVEL"] = "WARNING"
+        logger.setLevel("CRITICAL")
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # os.environ["LOG_LEVEL"] = "INFO"
+        logger.setLevel("INFO")
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
diff --git a/tests/data_import/test_insert_sessions.py b/tests/data_import/test_insert_sessions.py
index d7968d164..65a16170a 100644
--- a/tests/data_import/test_insert_sessions.py
+++ b/tests/data_import/test_insert_sessions.py
@@ -9,7 +9,6 @@
 from hdmf.backends.warnings import BrokenLinkWarning
 
 from spyglass.data_import.insert_sessions import copy_nwb_link_raw_ephys
-from spyglass.settings import raw_dir
 
 
 @pytest.fixture()
@@ -47,13 +46,26 @@ def new_nwbfile_raw_file_name(tmp_path):
     )
     nwbfile.add_acquisition(es)
 
-    _ = tmp_path  # CBroz: Changed to match testing base directory
+    spyglass_base_dir = tmp_path / "nwb-data"
+    os.environ["SPYGLASS_BASE_DIR"] = str(spyglass_base_dir)
+    os.mkdir(os.environ["SPYGLASS_BASE_DIR"])
 
-    file_name = "raw.nwb"
-    file_path = raw_dir + "/" + file_name
+    raw_dir = spyglass_base_dir / "raw"
+    os.mkdir(raw_dir)
+
+    dj.config["stores"] = {
+        "raw": {
+            "protocol": "file",
+            "location": str(raw_dir),
+            "stage": str(raw_dir),
+        },
+    }
 
+    file_name = "raw.nwb"
+    file_path = raw_dir / file_name
     with pynwb.NWBHDF5IO(str(file_path), mode="w") as io:
         io.write(nwbfile)
+
     return file_name
 
 
@@ -91,10 +103,8 @@ def test_copy_nwb(
         new_nwbfile_raw_file_name_abspath
     )
 
-    # test readability after moving the linking raw file (paths are stored as
-    # relative paths in NWB) so this should break the link (moving the linked-to
-    # file should also break the link)
-
+    # test readability after moving the linking raw file (paths are stored as relative paths in NWB)
+    # so this should break the link (moving the linked-to file should also break the link)
     shutil.move(out_nwb_file_abspath, moved_nwbfile_no_ephys_file_path)
     with pynwb.NWBHDF5IO(
         path=str(moved_nwbfile_no_ephys_file_path), mode="r"
diff --git a/tests/test_insert_beans.py b/tests/test_insert_beans.py
index d74ecb856..29b3d7fb1 100644
--- a/tests/test_insert_beans.py
+++ b/tests/test_insert_beans.py
@@ -1,10 +1,14 @@
-from datetime import datetime
-import kachery_cloud as kcl
 import os
 import pathlib
+from datetime import datetime
+
+import kachery_cloud as kcl
 import pynwb
 import pytest
 
+from spyglass.common import CameraDevice, DataAcquisitionDevice, Probe, Session
+from spyglass.data_import import insert_sessions
+
 
 @pytest.mark.skip(reason="test_path needs to be updated")
 def test_insert_sessions():
@@ -15,14 +19,6 @@ def test_insert_sessions():
     raw_dir = pathlib.Path(os.environ["SPYGLASS_BASE_DIR"]) / "raw"
     nwbfile_path = raw_dir / "test.nwb"
 
-    from spyglass.common import (
-        Session,
-        DataAcquisitionDevice,
-        CameraDevice,
-        Probe,
-    )
-    from spyglass.data_import import insert_sessions
-
     test_path = (
         "ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4"
     )
@@ -31,7 +27,8 @@ def test_insert_sessions():
     except Exception as e:
         if os.environ.get("KACHERY_CLOUD_EPHEMERAL", None) != "TRUE":
             print(
-                "Cannot load test file in non-ephemeral mode. Kachery cloud client may need to be registered."
+                "Cannot load test file in non-ephemeral mode. Kachery cloud "
+                + "client may need to be registered."
             )
         raise e
diff --git a/tests/test_nwb_helper_fn.py b/tests/test_nwb_helper_fn.py
index ad382b0a4..8095f80e0 100644
--- a/tests/test_nwb_helper_fn.py
+++ b/tests/test_nwb_helper_fn.py
@@ -3,8 +3,8 @@
 
 import pynwb
 
-# NOTE: importing this calls spyglass.__init__ whichand spyglass.common.__init__ which both require the
-# DataJoint MySQL server to be already set up and running
+# NOTE: importing this calls spyglass.__init__ and spyglass.common.__init__
+# which both require the DataJoint MySQL server to be up and running
 from spyglass.common import get_electrode_indices
 
 
@@ -48,7 +48,7 @@ def setUp(self):
         )
         self.nwbfile.add_acquisition(eseries)
 
-    def test_nwbfile(self):
+    def test_electrode_nwbfile(self):
         ret = get_electrode_indices(self.nwbfile, [102, 105])
         assert ret == [2, 5]
 

From 03be7d4b34159b7e9cdb97a29d6280716499de30 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Fri, 29 Dec 2023 17:07:11 -0600
Subject: [PATCH 02/16] Pytest WIP. Position centroid fix. Centralize device
 prompt logic

---
 .github/workflows/test-conda.yml             |  10 -
 pyproject.toml                               |  42 ++-
 src/spyglass/common/common_device.py         | 158 +++++-------
 src/spyglass/common/common_position.py       |   4 +-
 src/spyglass/common/common_session.py        |  23 +-
 src/spyglass/data_import/__init__.py         |   1 +
 src/spyglass/data_import/insert_sessions.py  |   2 +-
 src/spyglass/settings.py                     |  16 ++
 tests/ci_config.py                           |  27 --
 tests/{datajoint => common}/__init__.py      |   0
 tests/common/test_common_interval_helpers.py |  62 -----
 tests/common/test_insert.py                  | 174 +++++++++++++
 tests/common/test_interval_helpers.py        |  75 ++++++
 tests/conftest.py                            | 256 +++++++++++--------
 tests/container.py                           | 217 ++++++++++++++++
 tests/data_import/__init__.py                |   3 +
 tests/data_import/test_insert_sessions.py    | 126 ++-------
 tests/datajoint/_config.py                   |   1 -
 tests/datajoint/_datajoint_server.py         | 110 --------
 tests/old_tests.py                           | 180 +++++++++++++
 tests/test_insert_beans.py                   |  94 -------
 tests/trim_beans.py                          |  73 ------
 tests/{ => utils}/test_nwb_helper_fn.py      |   8 +-
 23 files changed, 974 insertions(+), 688 deletions(-)
 delete mode 100644 tests/ci_config.py
 rename tests/{datajoint => common}/__init__.py (100%)
 delete mode 100644 tests/common/test_common_interval_helpers.py
 create mode 100644 tests/common/test_insert.py
 create mode 100644 tests/common/test_interval_helpers.py
 create mode 100644 tests/container.py
 delete mode 100644 tests/datajoint/_config.py
 delete mode 100644 tests/datajoint/_datajoint_server.py
 create mode 100644 tests/old_tests.py
 delete mode 100644 tests/test_insert_beans.py
 delete mode 100644 tests/trim_beans.py
 rename tests/{ => utils}/test_nwb_helper_fn.py (89%)

diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml
index cd793a480..576713163 100644
--- a/.github/workflows/test-conda.yml
+++ b/.github/workflows/test-conda.yml
@@ -17,16 +17,6 @@ jobs:
     env:
       OS: ${{ matrix.os }}
       PYTHON: '3.8'
-      # SPYGLASS_BASE_DIR: ./data
-      # KACHERY_STORAGE_DIR: ./data/kachery-storage
-      # DJ_SUPPORT_FILEPATH_MANAGEMENT: True
-    # services:
-    #   datajoint_test_server:
-    #     image: datajoint/mysql
-    #     ports:
-    #       - 3306:3306
-    #     options: >-
-    #       -e MYSQL_ROOT_PASSWORD=tutorial
     steps:
       - name: Cancel Workflow Action
        uses: styfle/cancel-workflow-action@0.11.0
diff --git a/pyproject.toml b/pyproject.toml
index 22a397e17..30793d2af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,9 +68,10 @@ spyglass_cli = "spyglass.cli:cli"
 "Homepage" = "https://github.com/LorenFrankLab/spyglass"
 "Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues"
 
-[project.optional-dependencies]
+[project.optional-dependencies]
 position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"]
 test = [
+    "docker",  # for tests in a container
     "pytest",  # unit testing
     "pytest-cov",  # code coverage
     "kachery",  # database access
@@ -110,5 +111,42 @@ line-length = 80
 
 [tool.codespell]
 skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*'
+ignore-words-list = 'nevers'  # Nevers - name in Citation
 
-ignore-words-list = 'nevers'
+
+[tool.pytest.ini_options]
+minversion = "7.0"
+addopts = [
+    "-sv",
+    "-p no:warnings",
+    "--no-teardown",
+    "--quiet-spy",
+    "--show-capture=no",
+    "--pdbcls=IPython.terminal.debugger:TerminalPdb",  # use ipython debugger
+    "--cov=spyglass",
+    "--cov-report=term-missing",
+    "--no-cov-on-fail",
+]
+testpaths = ["tests"]
+log_level = "INFO"
+
+[tool.coverage.run]
+source = ["*/src/spyglass/*"]
+omit = [
+    "*/__init__.py",
+    "*/_version.py",
+    "*/cli/*",
+    # "*/common/*",
+    # "*/data_import/*",
+    "*/decoding/*",
+    "*/figurl_views/*",
+    "*/lfp/*",
+    "*/linearization/*",
+    "*/lock/*",
+    "*/position/*",
+    "*/ripple/*",
+    "*/sharing/*",
+    "*/spikesorting/*",
+    # "*/utils/*",
+]
+
diff --git a/src/spyglass/common/common_device.py b/src/spyglass/common/common_device.py
index 223862c81..2dd03c822 100644
--- a/src/spyglass/common/common_device.py
+++ b/src/spyglass/common/common_device.py
@@ -2,8 +2,8 @@
 import ndx_franklab_novela
 
 from spyglass.common.errors import PopulateException
-from spyglass.utils.dj_mixin import SpyglassMixin
-from spyglass.utils.logging import logger
+from spyglass.settings import test_mode
+from spyglass.utils import SpyglassMixin, logger
 from spyglass.utils.nwb_helper_fn import get_nwb_file
 
 schema = dj.schema("common_device")
@@ -154,25 +154,9 @@ def _add_device(cls, new_device_dict):
         all_values = DataAcquisitionDevice.fetch(
             "data_acquisition_device_name"
         ).tolist()
-        if name not in all_values:
-            # no entry with the same name exists, prompt user to add a new entry
-            logger.info(
-                f"\nData acquisition device '{name}' was not found in the "
-                f"database. The current values are: {all_values}. "
-                "Please ensure that the device you want to add does not already"
-                " exist in the database under a different name or spelling. "
-                "If you want to use an existing device in the database, "
-                "please change the corresponding Device object in the NWB file."
-                " Entering 'N' will raise an exception."
-            )
-            to_db = " to the database"
-            val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)")
-            if val.lower() in ["y", "yes"]:
-                cls.insert1(new_device_dict, skip_duplicates=True)
-                return
-            raise PopulateException(
-                f"User chose not to add device '{name}'{to_db}."
-            )
+        if prompt_insert(name=name, all_values=all_values):
+            cls.insert1(new_device_dict, skip_duplicates=True)
+            return
 
         # Check if values provided match the values stored in the database
         db_dict = (
@@ -213,28 +197,11 @@ def _add_system(cls, system):
         all_values = DataAcquisitionDeviceSystem.fetch(
             "data_acquisition_device_system"
         ).tolist()
-        if system not in all_values:
-            logger.info(
-                f"\nData acquisition device system '{system}' was not found in"
-                f" the database. The current values are: {all_values}. "
-                "Please ensure that the system you want to add does not already"
-                " exist in the database under a different name or spelling. "
-                "If you want to use an existing system in the database, "
-                "please change the corresponding Device object in the NWB file."
-                " Entering 'N' will raise an exception."
-            )
-            val = input(
-                f"Do you want to add data acquisition device system '{system}'"
-                + " to the database? (y/N)"
-            )
-            if val.lower() in ["y", "yes"]:
-                key = {"data_acquisition_device_system": system}
-                DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
-            else:
-                raise PopulateException(
-                    "User chose not to add data acquisition device system "
-                    + f"'{system}' to the database."
-                )
+        if prompt_insert(
+            name=system, all_values=all_values, table_type="system"
+        ):
+            key = {"data_acquisition_device_system": system}
+            DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
         return system
 
     @classmethod
@@ -264,30 +231,11 @@ def _add_amplifier(cls, amplifier):
         all_values = DataAcquisitionDeviceAmplifier.fetch(
             "data_acquisition_device_amplifier"
         ).tolist()
-        if amplifier not in all_values:
-            logger.info(
-                f"\nData acquisition device amplifier '{amplifier}' was not "
-                f"found in the database. The current values are: {all_values}. "
-                "Please ensure that the amplifier you want to add does not "
-                "already exist in the database under a different name or "
-                "spelling. If you want to use an existing name in the database,"
-                " please change the corresponding Device object in the NWB "
-                "file. Entering 'N' will raise an exception."
-            )
-            val = input(
-                "Do you want to add data acquisition device amplifier "
-                + f"'{amplifier}' to the database? (y/N)"
-            )
-            if val.lower() in ["y", "yes"]:
-                key = {"data_acquisition_device_amplifier": amplifier}
-                DataAcquisitionDeviceAmplifier.insert1(
-                    key, skip_duplicates=True
-                )
-            else:
-                raise PopulateException(
-                    "User chose not to add data acquisition device amplifier "
-                    + f"'{amplifier}' to the database."
-                )
+        if prompt_insert(
+            name=amplifier, all_values=all_values, table_type="amplifier"
+        ):
+            key = {"data_acquisition_device_amplifier": amplifier}
+            DataAcquisitionDeviceAmplifier.insert1(key, skip_duplicates=True)
 
         return amplifier
 
@@ -576,27 +524,9 @@ def _add_probe_type(cls, new_probe_type_dict):
         """
         probe_type = new_probe_type_dict["probe_type"]
         all_values = ProbeType.fetch("probe_type").tolist()
-        if probe_type not in all_values:
-            logger.info(
-                f"\nProbe type '{probe_type}' was not found in the database. "
-                f"The current values are: {all_values}. "
-                "Please ensure that the probe type you want to add does not "
-                "already exist in the database under a different name or "
-                "spelling. If you want to use an existing name in the "
-                "database, please change the corresponding Probe object in the "
-                "NWB file. Entering 'N' will raise an exception."
-            )
-            val = input(
-                f"Do you want to add probe type '{probe_type}' to the database?"
-                + " (y/N)"
-            )
-            if val.lower() in ["y", "yes"]:
-                ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
-                return
-            raise PopulateException(
-                f"User chose not to add probe type '{probe_type}' to the "
-                + "database."
-            )
+        if prompt_insert(probe_type, all_values, table="probe type"):
+            ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
+            return
 
         # else / entry exists: check whether the values provided match the
         # values stored in the database
@@ -738,3 +668,55 @@ def create_from_nwbfile(
             cls.Shank.insert1(shank, skip_duplicates=True)
         for electrode in elect_dict.values():
             cls.Electrode.insert1(electrode, skip_duplicates=True)
+
+
+# ---------------------------- Helper functions ----------------------------
+
+
+# Migrated down to reduce redundancy and centralize 'test_mode' check for pytest
+def prompt_insert(
+    name: str,
+    all_values: list,
+    table: str = "Data Acquisition Device",
+    table_type: str = None,
+) -> bool:
+    """Prompt user to add an item to the database. Return True if yes.
+
+    Assumes insert during test mode.
+
+    Parameters
+    ----------
+    name : str
+        The name of the item to add.
+    all_values : list
+        List of all values in the database.
+    table : str, optional
+        The name of the table to add to, by default Data Acquisition Device
+    table_type : str, optional
+        The type of item to add, by default None. e.g., the 'system' in
+        'Data Acquisition Device system'.
+    """
+    if name in all_values:
+        return False
+
+    if test_mode:
+        return True
+
+    table_type = f" {table_type}" if table_type else ""
+
+    logger.info(
+        f"{table}{table_type} '{name}' was not found in the "
+        f"database. The current values are: {all_values}.\n"
+        "Please ensure that the device you want to add does not already "
+        "exist in the database under a different name or spelling. If you "
+        "want to use an existing device in the database, please change the "
+        "corresponding Device object in the NWB file.\nEntering 'N' will "
+        "raise an exception."
+    )
+    msg = f"Do you want to add {table}{table_type} '{name}' to the database?"
+    if dj.utils.user_choice(msg).lower() in ["y", "yes"]:
+        return True
+
+    raise PopulateException(
+        f"User chose not to add {table}{table_type} '{name}' to the database."
+    )
diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py
index ea661a29d..d0fdc75dd 100644
--- a/src/spyglass/common/common_position.py
+++ b/src/spyglass/common/common_position.py
@@ -8,7 +8,7 @@
 import pynwb.behavior
 from position_tools import (
     get_angle,
-    get_centriod,
+    get_centroid,
     get_distance,
     get_speed,
     get_velocity,
@@ -417,7 +417,7 @@ def calculate_position_info(
     )
 
     # Calculate position, orientation, velocity, speed
-    position = get_centriod(back_LED, front_LED)  # cm
+    position = get_centroid(back_LED, front_LED)  # cm
     orientation = get_angle(back_LED, front_LED)  # radians
     is_nan = np.isnan(orientation)
 
diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py
index a237bc03c..7903a0d7a 100644
--- a/src/spyglass/common/common_session.py
+++ b/src/spyglass/common/common_session.py
@@ -1,10 +1,6 @@
 import datajoint as dj
 
-from spyglass.common.common_device import (
-    CameraDevice,
-    DataAcquisitionDevice,
-    Probe,
-)
+from spyglass.common.common_device import CameraDevice, DataAcquisitionDevice, Probe
 from spyglass.common.common_lab import Institution, Lab, LabMember
 from spyglass.common.common_nwbfile import Nwbfile
 from spyglass.common.common_subject import Subject
@@ -63,13 +59,15 @@ def make(self, key):
         nwbf = get_nwb_file(nwb_file_abspath)
         config = get_config(nwb_file_abspath)
 
-        # certain data are not associated with a single NWB file / session because they may apply to
-        # multiple sessions. these data go into dj.Manual tables.
-        # e.g., a lab member may be associated with multiple experiments, so the lab member table should not
-        # be dependent on (contain a primary key for) a session.
+        # certain data are not associated with a single NWB file / session
+        # because they may apply to multiple sessions. these data go into
+        # dj.Manual tables. e.g., a lab member may be associated with multiple
+        # experiments, so the lab member table should not be dependent on
+        # (contain a primary key for) a session.
 
-        # here, we create new entries in these dj.Manual tables based on the values read from the NWB file
-        # then, they are linked to the session via fields of Session (e.g., Subject, Institution, Lab) or part
+        # here, we create new entries in these dj.Manual tables based on the
+        # values read from the NWB file. then, they are linked to the session
+        # via fields of Session (e.g., Subject, Institution, Lab) or part
         # tables (e.g., Experimenter, DataAcquisitionDevice).
 
         logger.info("Institution...")
@@ -87,15 +85,12 @@ def make(self, key):
         if not debug_mode:  # TODO: remove when demo files agree on device
             logger.info("Populate DataAcquisitionDevice...")
             DataAcquisitionDevice.insert_from_nwbfile(nwbf, config)
-            logger.info()
 
         logger.info("Populate CameraDevice...")
         CameraDevice.insert_from_nwbfile(nwbf)
-        logger.info()
 
         logger.info("Populate Probe...")
         Probe.insert_from_nwbfile(nwbf, config)
-        logger.info()
 
         if nwbf.subject is not None:
             subject_id = nwbf.subject.subject_id
diff --git a/src/spyglass/data_import/__init__.py b/src/spyglass/data_import/__init__.py
index 703cfa3c1..9c68cf038 100644
--- a/src/spyglass/data_import/__init__.py
+++ b/src/spyglass/data_import/__init__.py
@@ -1 +1,2 @@
+# TODO: change naming to avoid match between module and function
 from .insert_sessions import insert_sessions
diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py
index c862fe85b..f31b0c09e 100644
--- a/src/spyglass/data_import/insert_sessions.py
+++ b/src/spyglass/data_import/insert_sessions.py
@@ -101,7 +101,7 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name):
     if os.path.exists(out_nwb_file_abs_path):
         if debug_mode:
             return out_nwb_file_abs_path
-        warnings.warn(
+        logger.warn(
             f"Output file {out_nwb_file_abs_path} exists and will be "
             + "overwritten."
         )
diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py
index 4672af615..122d53014 100644
--- a/src/spyglass/settings.py
+++ b/src/spyglass/settings.py
@@ -31,6 +31,7 @@ def __init__(self, base_dir: str = None, **kwargs):
         self._config = dict()
         self.config_defaults = dict(prepopulate=True)
         self._debug_mode = False
+        self._test_mode = False
         self._dlc_base = None
 
         self.relative_dirs = {
@@ -106,6 +107,7 @@ def load_config(self, force_reload=False):
             dj_dlc = dj_custom.get("dlc_dirs", {})
 
             self._debug_mode = dj_custom.get("debug_mode", False)
+            self._test_mode = dj_custom.get("test_mode", False)
 
             resolved_base = (
                 self.supplied_base_dir
@@ -166,6 +168,7 @@ def load_config(self, force_reload=False):
 
         self._config = dict(
             debug_mode=self._debug_mode,
+            test_mode=self._test_mode,
             **self.config_defaults,
             **config_dirs,
             **kachery_zone_dict,
@@ -381,6 +384,7 @@ def _dj_custom(self) -> dict:
         return {
             "custom": {
                 "debug_mode": str(self.debug_mode).lower(),
+                "test_mode": str(self._test_mode).lower(),
                 "spyglass_dirs": {
                     "base": self.base_dir,
                     "raw": self.raw_dir,
@@ -453,8 +457,19 @@ def video_dir(self) -> str:
 
     @property
     def debug_mode(self) -> bool:
+        """Returns True if debug_mode is set.
+
+        Supports skipping inserts for Dockerized development.
+        """
         return self._debug_mode
 
+    @property
+    def test_mode(self) -> bool:
+        """Returns True if test_mode is set.
+
+        Required for pytests to run without prompts."""
+        return self._test_mode
+
     @property
     def dlc_project_dir(self) -> str:
         return self.config.get(self.dir_to_var("project", "dlc"))
@@ -479,6 +494,7 @@ def dlc_output_dir(self) -> str:
 waveform_dir = sg_config.waveform_dir
 video_dir = sg_config.video_dir
 debug_mode = sg_config.debug_mode
+test_mode = sg_config.test_mode
 prepopulate = config.get("prepopulate", False)
 dlc_project_dir = sg_config.dlc_project_dir
 dlc_video_dir = sg_config.dlc_video_dir
diff --git a/tests/ci_config.py b/tests/ci_config.py
deleted file mode 100644
index e329df7ed..000000000
--- a/tests/ci_config.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-from pathlib import Path
-
-import datajoint as dj
-
-# NOTE this env var is set in the GitHub Action directly
-data_dir = Path(os.environ["SPYGLASS_BASE_DIR"])
-
-raw_dir = data_dir / "raw"
-analysis_dir = data_dir / "analysis"
-
-dj.config["database.host"] = "localhost"
-dj.config["database.user"] = "root"
-dj.config["database.password"] = "tutorial"
-dj.config["stores"] = {
-    "raw": {
-        "protocol": "file",
-        "location": str(raw_dir),
-        "stage": str(raw_dir),
-    },
-    "analysis": {
-        "protocol": "file",
-        "location": str(analysis_dir),
-        "stage": str(analysis_dir),
-    },
-}
-dj.config.save_global()
diff --git a/tests/datajoint/__init__.py b/tests/common/__init__.py
similarity index 100%
rename from tests/datajoint/__init__.py
rename to tests/common/__init__.py
diff --git a/tests/common/test_common_interval_helpers.py b/tests/common/test_common_interval_helpers.py
deleted file mode 100644
index 293abda91..000000000
--- a/tests/common/test_common_interval_helpers.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import numpy as np
-from spyglass.common.common_interval import (
-    interval_list_intersect,
-    interval_set_difference_inds,
-)
-
-
-def test_interval_list_intersect1():
-    interval_list1 = np.array([[0, 10], [3, 5], [14, 16]])
-    interval_list2 = np.array([[10, 11], [9, 14], [13, 18]])
-    intersection_list = interval_list_intersect(interval_list1, interval_list2)
-    assert np.all(intersection_list == np.array([[9, 10], [14, 16]]))
-
-
-def test_interval_list_intersect2():
-    # if there is no intersection, return empty list
-    interval_list1 = np.array([[0, 10], [3, 5]])
-    interval_list2 = np.array([[11, 14]])
-    intersection_list = interval_list_intersect(interval_list1, interval_list2)
-    assert len(intersection_list) == 0
-
-
-def test_interval_set_difference_inds_no_overlap():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = [(5, 8)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 5), (8, 10)]
-
-
-def test_interval_set_difference_inds_overlap():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = [(1, 2), (3, 4), (6, 9)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 1), (2, 3), (4, 5), (9, 10)]
-
-
-def test_interval_set_difference_inds_empty_intervals1():
-    intervals1 = []
-    intervals2 = [(1, 2), (3, 4), (6, 9)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == []
-
-
-def test_interval_set_difference_inds_empty_intervals2():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = []
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 5), (8, 10)]
-
-
-def test_interval_set_difference_inds_equal_intervals():
-    intervals1 = [(0, 5), (8, 10)]
-    intervals2 = [(0, 5), (8, 10)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == []
-
-
-def test_interval_set_difference_inds_multiple_overlaps():
-    intervals1 = [(0, 10)]
-    intervals2 = [(1, 3), (4, 6), (7, 9)]
-    result = interval_set_difference_inds(intervals1, intervals2)
-    assert result == [(0, 1), (3, 4), (6, 7), (9, 10)]
diff --git a/tests/common/test_insert.py b/tests/common/test_insert.py
new file mode 100644
index 000000000..8e57e22d3
--- /dev/null
+++ b/tests/common/test_insert.py
@@ -0,0 +1,174 @@
+from pytest import approx
+
+
+def test_load_file(minirec_content):
+    assert minirec_content is not None
+
+
+def test_insert_session(minirec_insert, minirec_content, minirec_restr, common):
+    subj_raw = minirec_content.subject
+    meta_raw = minirec_content
+
+    sess_data = (common.Session & minirec_restr).fetch1()
+    assert (
+        sess_data["subject_id"] == subj_raw.subject_id
+    ), "Subject ID does not match"
+
+    attrs = [
+        ("institution_name", "institution"),
+        ("lab_name", "lab"),
+        ("session_id", "session_id"),
+        ("session_description", "session_description"),
+        ("experiment_description", "experiment_description"),
+    ]
+
+    for sess_attr, meta_attr in attrs:
+        assert sess_data[sess_attr] == getattr(
+            meta_raw, meta_attr
+        ), f"Session table {sess_attr} does not match raw data {meta_attr}"
+
+    time_attrs = [
+        ("session_start_time", "session_start_time"),
+        ("timestamps_reference_time", "timestamps_reference_time"),
+    ]
+    for sess_attr, meta_attr in time_attrs:
+        # a. strip timezone info from meta_raw
+        # b. convert to timestamp
+        # c. compare precision to 1 second
+        assert sess_data[sess_attr].timestamp() == approx(
+            getattr(meta_raw, meta_attr).replace(tzinfo=None).timestamp(), abs=1
+        ), f"Session table {sess_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_electrode_group(minirec_insert, minirec_content, common):
+    group_name = "0"
+    egroup_data = (
+        common.ElectrodeGroup & {"electrode_group_name": group_name}
+    ).fetch1()
+    egroup_raw = minirec_content.electrode_groups.get(group_name)
+
+    assert (
+        egroup_data["description"] == egroup_raw.description
+    ), "ElectrodeGroup description does not match"
+
+    assert egroup_data["region_id"] == (
+        common.BrainRegion & {"region_name": egroup_raw.location}
+    ).fetch1(
+        "region_id"
+    ), "Region ID does not match across raw data and BrainRegion table"
+
+
+def test_insert_electrode(
+    minirec_insert, minirec_content, minirec_restr, common
+):
+    electrode_id = "0"
+    e_data = (common.Electrode & {"electrode_id": electrode_id}).fetch1()
+    e_raw = minirec_content.electrodes.get(int(electrode_id)).to_dict().copy()
+
+    attrs = [
+        ("x", "x"),
+        ("y", "y"),
+        ("z", "z"),
+        ("impedance", "imp"),
+        ("filtering", "filtering"),
+        ("original_reference_electrode", "ref_elect_id"),
+    ]
+
+    for e_attr, meta_attr in attrs:
+        assert (  # KeyError: 0 here ↓
+            e_data[e_attr] == e_raw[int(electrode_id)][meta_attr]
+        ), f"Electrode table {e_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_raw(minirec_insert, minirec_content, minirec_restr, common):
+    raw_data = (common.Raw & minirec_restr).fetch1()
+    raw_raw = minirec_content.get_acquisition()
+
+    attrs = [
+        ("comments", "comments"),
+        ("description", "description"),
+    ]
+    for raw_attr, meta_attr in attrs:
+        assert raw_data[raw_attr] == getattr(
+            raw_raw, meta_attr
+        ), f"Raw table {raw_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_sample_count(minirec_insert, minirec_content, common):
+    # common.SampleCount
+    assert False, "TODO"
+
+
+def test_insert_dio(minirec_insert, minirec_content, common):
+    # common.DIOEvents
+    assert False, "TODO"
+
+
+def test_insert_pos(minirec_insert, minirec_content, common):
+    # common.PositionSource * common.RawPosition
+    assert False, "TODO"
+
+
+def test_insert_device(minirec_insert, minirec_devices, common):
+    this_device = "dataacq_device0"
+    device_raw = minirec_devices.get(this_device)
+    device_data = (
+        common.DataAcquisitionDevice
+        & {"data_acquisition_device_name": this_device}
+    ).fetch1()
+
+    attrs = [
+        ("data_acquisition_device_name", "name"),
+        ("data_acquisition_device_system", "system"),
+        ("data_acquisition_device_amplifier", "amplifier"),
+        ("adc_circuit", "adc_circuit"),
+    ]
+
+    for device_attr, meta_attr in attrs:
+        assert device_data[device_attr] == getattr(
+            device_raw, meta_attr
+        ), f"Device table {device_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_camera(minirec_insert, minirec_devices, common):
+    camera_raw = minirec_devices.get("camera_device 0")
+    camera_data = (
+        common.CameraDevice & {"camera_name": camera_raw.camera_name}
+    ).fetch1()
+
+    attrs = [
+        ("camera_name", "camera_name"),
+        ("manufacturer", "manufacturer"),
+        ("model", "model"),
+        ("lens", "lens"),
+        ("meters_per_pixel", "meters_per_pixel"),
+    ]
+    for camera_attr, meta_attr in attrs:
+        assert camera_data[camera_attr] == getattr(
+            camera_raw, meta_attr
+        ), f"Camera table {camera_attr} does not match raw data {meta_attr}"
+
+
+def test_insert_probe(minirec_insert, minirec_devices, common):
+    this_probe = "probe 0"
+    probe_raw = minirec_devices.get(this_probe)
+    probe_id = probe_raw.probe_type
+
+    probe_data = (
+        common.Probe * common.ProbeType & {"probe_id": probe_id}
+    ).fetch1()
+
+    attrs = [
+        ("probe_type", "probe_type"),
+        ("probe_description", "probe_description"),
+        ("contact_side_numbering", "contact_side_numbering"),
+    ]
+
+    for probe_attr, meta_attr in attrs:
+        assert probe_data[probe_attr] == str(
+            getattr(probe_raw, meta_attr)
+        ), f"Probe table {probe_attr} does not match raw data {meta_attr}"
+
+    assert probe_data["num_shanks"] == len(
+        probe_raw.shanks
+    ), "Number of shanks in ProbeType does not match raw data"
diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py
new file mode 100644
index 000000000..d91ea4a96
--- /dev/null
+++ b/tests/common/test_interval_helpers.py
@@ -0,0 +1,75 @@
+import numpy as np
+import pytest
+from numpy import all, array
+
+
+@pytest.fixture(scope="session")
+def list_intersect(common):
+    yield common.common_interval.interval_list_intersect
+
+
+@pytest.mark.parametrize(
+    "one, two, result",
+    [
+        (
+            np.array([[0, 10], [3, 5], [14, 16]]),
+            np.array([[10, 11], [9, 14], [13, 18]]),
+            np.array([[9, 10], [14, 16]]),
+        ),
+        (  # Empty result for no intersection
+            np.array([[0, 10], [3, 5]]),
+            np.array([[11, 14]]),
+            np.array([]),
+        ),
+    ],
+)
+def test_list_intersect(list_intersect, one, two, result):
+    assert np.array_equal(
+        list_intersect(one, two), result
+    ), "Problem with common_interval.interval_list_intersect"
+
+
+@pytest.fixture(scope="session")
+def set_difference(common):
+    yield common.common_interval.interval_set_difference_inds
+
+
+@pytest.mark.parametrize(
+    "one, two, expected_result",
+    [
+        (  # No overlap
+            [(0, 5), (8, 10)],
+            [(5, 8)],
+            [(0, 5), (8, 10)],
+        ),
+        (  # Overlap
+            [(0, 5), (8, 10)],
+            [(1, 2), (3, 4), (6, 9)],
+            [(0, 1), (2, 3), (4, 5), (9, 10)],
+        ),
+        (  # One empty
+            [],
+            [(1, 2), (3, 4), (6, 9)],
+            [],
+        ),
+        (  # Two empty
+            [(0, 5), (8, 10)],
+            [],
+            [(0, 5), (8, 10)],
+        ),
+        (  # Equal intervals
+            [(0, 5), (8, 10)],
+            [(0, 5), (8, 10)],
+            [],
+        ),
+        (  # Multiple overlaps
+            [(0, 10)],
+            [(1, 3), (4, 6), (7, 9)],
+            [(0, 1), (3, 4), (6, 7), (9, 10)],
+        ),
+    ],
+)
+def test_set_difference(set_difference, one, two, expected_result):
+    assert (
+        set_difference(one, two) == expected_result
+    ), "Problem with common_interval.interval_set_difference_inds"
diff --git a/tests/conftest.py b/tests/conftest.py
index ca80b5749..4213b7a9b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,26 +1,20 @@
 import os
-import pathlib
-import shutil
 import sys
-import tempfile
+import warnings
 from contextlib import nullcontext
+from pathlib import Path
 
 import datajoint as dj
+import pynwb
 import pytest
 from datajoint.logging import logger
 
-from .datajoint._config import DATAJOINT_SERVER_PORT
-from .datajoint._datajoint_server import (
-    kill_datajoint_server,
-    run_datajoint_server,
-)
+from .container import DockerMySQLManager
 
 # ---------------------- CONSTANTS ---------------------
 
-thisdir = os.path.dirname(os.path.realpath(__file__))
-sys.path.append(thisdir)
-global __PROCESS
-__PROCESS = None
+# globals in pytest_configure: BASE_DIR, SERVER, TEARDOWN, VERBOSE
+warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
 
 
 def pytest_addoption(parser):
@@ -43,7 +37,7 @@ def pytest_addoption(parser):
         action="store_true",
         dest="quiet_spy",
         default=False,
-        help="Quiet print statements from Spyglass.",
+        help="Quiet logging from Spyglass.",
     )
     parser.addoption(
         "--no-server",
@@ -60,140 +54,198 @@ def pytest_addoption(parser):
         help="Do not tear down tables after tests.",
     )
     parser.addoption(
-        "--datadir",
+        "--base-dir",
         action="store",
-        default="./tests/test_data/",
-        dest="datadir",
+        default="./tests/_data/",
+        dest="base_dir",
         help="Directory for local input file.",
     )
 
 
+def pytest_configure(config):
+    global BASE_DIR, SERVER, TEARDOWN, VERBOSE
+
+    TEARDOWN = not config.option.no_teardown
+    VERBOSE = not config.option.quiet_spy
+
+    BASE_DIR = Path(config.option.base_dir).absolute()
+    BASE_DIR.mkdir(parents=True, exist_ok=True)
+    os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR)
+
+    SERVER = DockerMySQLManager(
+        restart=False,
+        shutdown=TEARDOWN,
+        null_server=config.option.no_server,
+        verbose=VERBOSE,
+    )
+
+
 # ------------------- FIXTURES -------------------
 
 
 @pytest.fixture(scope="session")
-def verbose_context(config):
-    """Verbosity context for suppressing Spyglass print statements."""
-    return QuietStdOut() if config.option.quiet_spy else nullcontext()
+def verbose():
+    """Config for pytest fixtures."""
+    yield VERBOSE
+
+
+@pytest.fixture(scope="session", autouse=True)
+def verbose_context(verbose):
+    """Verbosity context for suppressing Spyglass logging."""
+    yield nullcontext() if verbose else QuietStdOut()
+
+
+@pytest.fixture(scope="session")
+def teardown(request):
+    yield TEARDOWN
+
+
+@pytest.fixture(scope="session")
+def server(request, teardown):
+    SERVER.wait()
+    yield SERVER
+    if teardown:
+        SERVER.stop()
+
+
+@pytest.fixture(scope="session")
+def dj_conn(request, server, verbose, teardown):
+    """Fixture for datajoint connection."""
+    config_file = "dj_local_conf.json_pytest"
+
+    dj.config.update(server.creds)
+    dj.config["loglevel"] = "INFO" if verbose else "ERROR"
+    dj.config.save(config_file)
+    dj.conn()
+    yield dj.conn()
+    if teardown:
+        if Path(config_file).exists():
+            os.remove(config_file)
+
+
+@pytest.fixture(scope="session")
+def base_dir():
+    yield BASE_DIR
+
+
+@pytest.fixture(scope="session")
+def raw_dir(base_dir):
+    # could do
+    # could do settings.raw_dir, but this is faster while server booting
+    yield base_dir / "raw"
 
 
 @pytest.fixture(scope="session")
-def teardown(config):
-    return not config.option.no_teardown
+def minirec_path(raw_dir):
+    yield raw_dir / "test.nwb"
 
 
 @pytest.fixture(scope="session")
-def spy_config(config):
+def minirec_download():
+    # test_path = (
+    #     "ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4"
+    # )
+    # try:
+    #     local_test_path = kcl.load_file(test_path)
+    # except Exception as e:
+    #     if os.environ.get("KACHERY_CLOUD_EPHEMERAL", None) != "TRUE":
+    #         print(
+    #             "Cannot load test file in non-ephemeral mode. Kachery cloud"
+    #             + "client may need to be registered."
+    #         )
+    #         raise e
+    #     os.rename(local_test_path, nwbfile_path)
     pass
 
 
-def pytest_configure(config):
-    """Run on build, after parsing command line options."""
-    _set_env(base_dir=config.option.datadir)
+@pytest.fixture(scope="session")
+def minirec_content(minirec_path):
+    with pynwb.NWBHDF5IO(
+        path=str(minirec_path), mode="r", load_namespaces=True
+    ) as io:
+        nwbfile = io.read()
+        yield nwbfile
+
+
+@pytest.fixture(scope="session")
+def minirec_open(minirec_content):
+    yield minirec_content
 
-    # note that in this configuration, every test will use the same datajoint
-    # server. this may create conflicts and dependencies between tests. it may be
-    # better but significantly slower to start a new server for every test, but
-    # the server needs to be started before tests are collected because
-    # datajoint runs when the source files are loaded, not when the tests are
-    # run. one solution might be to restart the server after every test
-    if not config.option.no_server:
-        global __PROCESS
-        __PROCESS = run_datajoint_server()
 
+@pytest.fixture(scope="session")
+def minirec_closed(minirec_path):
+    with pynwb.NWBHDF5IO(
+        path=str(minirec_path), mode="r", load_namespaces=True
+    ) as io:
+        nwbfile = io.read()
+    yield nwbfile
 
-def pytest_unconfigure(config):
-    """Called before test process is exited."""
-    if __PROCESS:
-        logger.info("Terminating datajoint compute resource process")
-        __PROCESS.terminate()
-        # TODO handle ResourceWarning: subprocess X is still running __PROCESS.join()
-
-    if not config.option.no_server:
-        kill_datajoint_server()
-        shutil.rmtree(os.environ["SPYGLASS_BASE_DIR"])
+
+@pytest.fixture(scope="session")
+def minirec_devices(minirec_content):
+    yield minirec_content.devices
+
+
+@pytest.fixture(scope="session")
+def minirec_insert(minirec_path, teardown, server, dj_conn):
+    from spyglass.common import Nwbfile, Session  # noqa: E402
+    from spyglass.data_import import insert_sessions  # noqa: E402
+    from spyglass.utils.nwb_helper_fn import close_nwb_files  # noqa: E402
+
+    if len(Nwbfile()) > 0:
+        Nwbfile().delete(safemode=False)
+
+    if server.connected:
+        insert_sessions(minirec_path.name)
+    else:
+        logger.error("No server connection.")
+    if len(Session()) == 0:
+        logger.error("No sessions inserted.")
+    yield
 
+    close_nwb_files()
+    if teardown:
+        Nwbfile().delete(safemode=False)
 
-# ------------------ GENERAL FUNCTION ------------------
 
+@pytest.fixture(scope="session")
+def minirec_restr(minirec_path):
+    yield f"nwb_file_name LIKE '{minirec_path.stem}%'"
+
+
+@pytest.fixture(scope="session")
+def common(dj_conn):
+    from spyglass import common
+
+    yield common
+
+
+@pytest.fixture(scope="session")
+def data_import(dj_conn):
+    from spyglass import data_import
+
+    yield data_import
+
+
+@pytest.fixture(scope="session")
+def settings(dj_conn):
+    from spyglass import settings
+
+    yield settings
+
+
+# ------------------ GENERAL FUNCTION ------------------
 
-def _set_env(base_dir):
-    """Set environment variables."""
-
-    # TODO: change from tempdir to user supplied dir
-    # spyglass_base_dir = pathlib.Path(base_dir)
-    spyglass_base_dir = pathlib.Path(tempfile.mkdtemp())
-
-    spike_sorting_storage_dir = spyglass_base_dir / "spikesorting"
-    tmp_dir = spyglass_base_dir / "tmp"
-
-    logger.info("Setting datajoint and kachery environment variables.")
-    logger.info(f"SPYGLASS_BASE_DIR set to {spyglass_base_dir}")
-
-    # TODO: make this a fixture
-    # spy_config_dict = dict(
-    #     SPYGLASS_BASE_DIR=str(spyglass_base_dir),
-    #     SPYGLASS_RECORDING_DIR=str(spyglass_base_dir / "recording"),
-    #     SPYGLASS_SORTING_DIR=str(spyglass_base_dir / "sorting"),
-    #     SPYGLASS_WAVEFORMS_DIR=str(spyglass_base_dir / "waveforms"),
-    #     SPYGLASS_TEMP_DIR=str(tmp_dir),
-    #     SPIKE_SORTING_STORAGE_DIR=str(spike_sorting_storage_dir),
-    #     KACHERY_ZONE="franklab.collaborators",
-    #     KACHERY_CLOUD_DIR="/stelmo/nwb/.kachery_cloud",
-    #     KACHERY_STORAGE_DIR=str(spyglass_base_dir / "kachery_storage"),
-    #     KACHERY_TEMP_DIR=str(spyglass_base_dir / "tmp"),
-    #     FIGURL_CHANNEL="franklab2",
-    #     DJ_SUPPORT_FILEPATH_MANAGEMENT="TRUE",
-    #     KACHERY_CLOUD_EPHEMERAL="TRUE",
-    # )
-
-    os.environ["SPYGLASS_BASE_DIR"] = str(spyglass_base_dir)
-    os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE"
-    os.environ["SPIKE_SORTING_STORAGE_DIR"] = str(spike_sorting_storage_dir)
-    os.environ["SPYGLASS_TEMP_DIR"] = str(tmp_dir)
-    os.environ["KACHERY_CLOUD_EPHEMERAL"] = "TRUE"
-
-    os.mkdir(spike_sorting_storage_dir)
-    os.mkdir(tmp_dir)
-
-    raw_dir = spyglass_base_dir / "raw"
-    analysis_dir = spyglass_base_dir / "analysis"
-
-    os.mkdir(raw_dir)
-    os.mkdir(analysis_dir)
-
-    dj.config["database.host"] = "localhost"
-    dj.config["database.port"] = DATAJOINT_SERVER_PORT
-    dj.config["database.user"] = "root"
-    dj.config["database.password"] = "tutorial"
-
-    dj.config["stores"] = {
-        "raw": {
-            "protocol": "file",
-            "location": str(raw_dir),
-            "stage": str(raw_dir),
-        },
-        "analysis": {
-            "protocol": "file",
-            "location": str(analysis_dir),
-            "stage": str(analysis_dir),
-        },
-    }
-
 
 class QuietStdOut:
     """If quiet_spy, used to quiet prints, teardowns and table.delete prints"""
 
     def __enter__(self):
-        # os.environ["LOG_LEVEL"] = "WARNING"
         logger.setLevel("CRITICAL")
         self._original_stdout = sys.stdout
         sys.stdout = open(os.devnull, "w")
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # os.environ["LOG_LEVEL"] = "INFO"
         logger.setLevel("INFO")
         sys.stdout.close()
         sys.stdout = self._original_stdout
diff --git a/tests/container.py b/tests/container.py
new file mode 100644
index 000000000..d178e1fce
--- /dev/null
+++ b/tests/container.py
@@ -0,0 +1,217 @@
+import atexit
+import time
+
+import datajoint as dj
+import docker
+from datajoint import logger
+
+
+class DockerMySQLManager:
+    """Manage Docker container for MySQL server
+
+    Parameters
+    ----------
+    image_name : str
+        Docker image name. Default 'datajoint/mysql'.
+    mysql_version : str
+        MySQL version. Default '8.0'.
+    container_name : str
+        Docker container name. Default 'spyglass-pytest'.
+    port : str
+        Port to map to DJ's default 3306. Default '330[mysql_version]'
+        (i.e., 3308 if testing 8.0).
+    null_server : bool
+        If True, do not start container. Return on all methods. Default False.
+        Useful for iterating on tests in existing container.
+    restart : bool
+        If True, stop and remove existing container on startup. Default True.
+    shutdown : bool
+        If True, stop and remove container on exit from python. Default True.
+    verbose : bool
+        If True, print container status on startup. Default False.
+    """
+
+    def __init__(
+        self,
+        image_name="datajoint/mysql",
+        mysql_version="8.0",
+        container_name="spyglass-pytest",
+        port=None,
+        null_server=False,
+        restart=True,
+        shutdown=True,
+        verbose=False,
+    ) -> None:
+        self.image_name = image_name
+        self.mysql_version = mysql_version
+        self.container_name = container_name
+        self.port = port or "330" + self.mysql_version[0]
+        self.client = docker.from_env()
+        self.null_server = null_server
+        self.password = "tutorial"
+        self.user = "root"
+        self.host = "localhost"
+        self._ran_container = None
+        self.logger = logger
+        self.logger.setLevel("INFO" if verbose else "ERROR")
+
+        if not self.null_server:
+            if shutdown:
+                atexit.register(self.stop)  # stop container on python exit
+            if restart:
+                self.stop()  # stop container if it exists
+            self.start()
+
+    @property
+    def container(self) -> docker.models.containers.Container:
+        return self.client.containers.get(self.container_name)
+
+    @property
+    def container_status(self) -> str:
+        try:
+            self.container.reload()
+            return self.container.status
+        except docker.errors.NotFound:
+            return None
+
+    @property
+    def container_health(self) -> str:
+        try:
+            self.container.reload()
+            return self.container.health
+        except docker.errors.NotFound:
+            return None
+
+    @property
+    def msg(self) -> str:
+        return f"Container {self.container_name} "
+
+    def start(self) -> str:
+        if self.null_server:
+            return None
+
+        elif self.container_status in ["created", "running", "restarting"]:
+            self.logger.info(
+                self.msg + "starting: " + self.container_status + "."
+            )
+
+        elif self.container_status == "exited":
+            self.logger.info(self.msg + "restarting.")
+            self.container.restart()
+
+        else:
+            self._ran_container = self.client.containers.run(
+                image=f"{self.image_name}:{self.mysql_version}",
+                name=self.container_name,
+                ports={3306: self.port},
+                environment=[
+                    f"MYSQL_ROOT_PASSWORD={self.password}",
+                    "MYSQL_DEFAULT_STORAGE_ENGINE=InnoDB",
+                ],
+                detach=True,
+                tty=True,
+            )
+            self.logger.info(self.msg + "starting new.")
+
+        return self.container.name
+
+    def wait(self, timeout=120, wait=5) -> None:
+        """Wait for healthy container.
+
+        Parameters
+        ----------
+        timeout : int
+            Timeout in seconds. Default 120.
+        wait : int
+            Time to wait between checks in seconds. Default 5.
+        """
+
+        if self.null_server:
+            return None
+        if not self.container_status or self.container_status == "exited":
+            self.start()
+
+        for _ in range(timeout // wait):
+            if self.container.health == "healthy":
+                break
+            self.logger.info(f"Container {self.container_name} starting...")
+            time.sleep(wait)
+        self.logger.info(
+            f"Container {self.container_name}, {self.container.health}."
+        )
+
+    @property
+    def _add_sql(self) -> str:
+        ESC = r"\_%"
+        return (
+            "CREATE USER IF NOT EXISTS 'basic'@'%' IDENTIFIED BY "
+            + f"'{self.password}'; GRANT USAGE ON `%`.* TO 'basic'@'%';"
+            + "GRANT SELECT ON `%`.* TO 'basic'@'%';"
+            + f"GRANT ALL PRIVILEGES ON `common{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `spikesorting{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `lfp{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `position{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `ripple{ESC}`.* TO `basic`@`%`;"
+            + f"GRANT ALL PRIVILEGES ON `linearization{ESC}`.* TO `basic`@`%`;"
+        ).strip()
+
+    def add_user(self) -> int:
+        """Add 'basic' user to container."""
+        if self.null_server:
+            return None
+
+        if self.container_status == "running":
+            result = self.container.exec_run(
+                cmd=[
+                    "mysql",
+                    "-u",
+                    self.user,
+                    f"--password={self.password}",
+                    "-e",
+                    self._add_sql,
+                ],
+                stdout=False,
+                stderr=False,
+                tty=True,
+            )
+            if result.exit_code == 0:
+                self.logger.info("Container added user.")
+            else:
+                logger.error("Failed to add user.")
+            return result.exit_code
+        else:
+            logger.error(f"Container {self.container_name} does not exist.")
+            return None
+
+    @property
+    def creds(self):
+        """Datajoint credentials for this container."""
+        return {
+            "database.host": "localhost",
+            "database.password": self.password,
+            "database.user": self.user,
+            "database.port": int(self.port),
+            "safemode": "false",
+            "custom": {"test_mode": True},
+        }
+
+    @property
+    def connected(self) -> bool:
+        self.wait()
+        dj.config.update(self.creds)
+        return dj.conn().is_connected
+
+    def stop(self, remove=True) -> None:
+        """Stop and remove container."""
+        if self.null_server:
+            return None
+        if not self.container_status or self.container_status == "exited":
+            self.logger.info(
+                f"Container {self.container_name} already stopped."
+            )
+            return
+        self.container.stop()
+        self.logger.info(f"Container {self.container_name} stopped.")
+        if remove:
+            self.container.remove()
+            self.logger.info(f"Container {self.container_name} removed.")
diff --git a/tests/data_import/__init__.py b/tests/data_import/__init__.py
index e69de29bb..8f7eaee37 100644
--- a/tests/data_import/__init__.py
+++ b/tests/data_import/__init__.py
@@ -0,0 +1,3 @@
+# NOTE: test_insert_sessions does not increase coverage over common/test_insert
+# but it does declare its own nwbfile without downloading and tests broken
+# links which aren't technically part of spyglass
diff --git a/tests/data_import/test_insert_sessions.py b/tests/data_import/test_insert_sessions.py
index 65a16170a..ea907fcc1 100644
--- a/tests/data_import/test_insert_sessions.py
+++ b/tests/data_import/test_insert_sessions.py
@@ -1,114 +1,42 @@
 import datetime
-import os
-import pathlib
 import shutil
+import warnings
+from pathlib import Path
 
-import datajoint as dj
 import pynwb
 import pytest
 from hdmf.backends.warnings import BrokenLinkWarning
 
-from spyglass.data_import.insert_sessions import copy_nwb_link_raw_ephys
 
-
-@pytest.fixture()
-def new_nwbfile_raw_file_name(tmp_path):
-    nwbfile = pynwb.NWBFile(
-        session_description="session_description",
-        identifier="identifier",
-        session_start_time=datetime.datetime.now(datetime.timezone.utc),
-    )
-
-    device = nwbfile.create_device("dev1")
-    group = nwbfile.create_electrode_group(
-        "tetrode1", "tetrode description", "tetrode location", device
-    )
-    nwbfile.add_electrode(
-        id=1,
-        x=1.0,
-        y=2.0,
-        z=3.0,
-        imp=-1.0,
-        location="CA1",
-        filtering="none",
-        group=group,
-        group_name="tetrode1",
-    )
-    region = nwbfile.create_electrode_table_region(
-        region=[0], description="electrode 1"
+@pytest.fixture(scope="session")
+def copy_nwb_link_raw_ephys(data_import):
+    from spyglass.data_import.insert_sessions import (  # noqa: E402
+        copy_nwb_link_raw_ephys,
     )
-    es = pynwb.ecephys.ElectricalSeries(
-        name="test_ts",
-        data=[1, 2, 3],
-        timestamps=[1.0, 2.0, 3.0],
-        electrodes=region,
-    )
-    nwbfile.add_acquisition(es)
-
-    spyglass_base_dir = tmp_path / "nwb-data"
-    os.environ["SPYGLASS_BASE_DIR"] = str(spyglass_base_dir)
-    os.mkdir(os.environ["SPYGLASS_BASE_DIR"])
-
-    raw_dir = spyglass_base_dir / "raw"
-    os.mkdir(raw_dir)
-
-    dj.config["stores"] = {
-        "raw": {
-            "protocol": "file",
-            "location": str(raw_dir),
-            "stage": str(raw_dir),
-        },
-    }
+    return copy_nwb_link_raw_ephys
 
-    file_name = "raw.nwb"
-    file_path = raw_dir / file_name
 
-    with pynwb.NWBHDF5IO(str(file_path), mode="w") as io:
-        io.write(nwbfile)
+def test_open_path(minirec_path, minirec_open):
+    this_acq = minirec_open.acquisition
+    assert "e-series" in this_acq, "Ephys link no longer exists"
+    assert (
+        str(minirec_path) == this_acq["e-series"].data.file.filename
+    ), "Path of ephys link is incorrect"
 
-    return file_name
 
-
-@pytest.fixture()
-def new_nwbfile_no_ephys_file_name():
-    return "raw_no_ephys.nwb"
-
-
-@pytest.fixture()
-def moved_nwbfile_no_ephys_file_path(tmp_path, new_nwbfile_no_ephys_file_name):
-    return tmp_path / new_nwbfile_no_ephys_file_name
-
-
-def test_copy_nwb(
-    new_nwbfile_raw_file_name,
-    new_nwbfile_no_ephys_file_name,
-    moved_nwbfile_no_ephys_file_path,
+def test_copy_link(
+    minirec_path, settings, minirec_closed, copy_nwb_link_raw_ephys
 ):
-    copy_nwb_link_raw_ephys(
-        new_nwbfile_raw_file_name, new_nwbfile_no_ephys_file_name
-    )
-
-    # new file should not have ephys data
-    base_dir = pathlib.Path(os.getenv("SPYGLASS_BASE_DIR", None))
new_nwbfile_raw_file_name_abspath = ( - base_dir / "raw" / new_nwbfile_raw_file_name - ) - out_nwb_file_abspath = base_dir / "raw" / new_nwbfile_no_ephys_file_name - with pynwb.NWBHDF5IO(path=str(out_nwb_file_abspath), mode="r") as io: - nwbfile = io.read() - assert ( - "test_ts" in nwbfile.acquisition - ) # this still exists but should be a link now - assert nwbfile.acquisition["test_ts"].data.file.filename == str( - new_nwbfile_raw_file_name_abspath - ) - - # test readability after moving the linking raw file (paths are stored as relative paths in NWB) - # so this should break the link (moving the linked-to file should also break the link) - shutil.move(out_nwb_file_abspath, moved_nwbfile_no_ephys_file_path) - with pynwb.NWBHDF5IO( - path=str(moved_nwbfile_no_ephys_file_path), mode="r" - ) as io: - with pytest.warns(BrokenLinkWarning): - nwbfile = io.read() # should raise BrokenLinkWarning - assert "test_ts" not in nwbfile.acquisition + """Test readabilty after moving the linking raw file, breaking link""" + new_path = Path(settings.raw_dir) / "no_ephys.nwb" + new_moved = Path(settings.temp_dir) / "no_ephys_moved.nwb" + + copy_nwb_link_raw_ephys(minirec_path.name, new_path.name) + shutil.move(new_path, new_moved) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + with pynwb.NWBHDF5IO(path=str(new_moved), mode="r") as io: + with pytest.warns(BrokenLinkWarning): + nwb_acq = io.read().acquisition + assert "e-series" not in nwb_acq, "Ephys link still exists after move" diff --git a/tests/datajoint/_config.py b/tests/datajoint/_config.py deleted file mode 100644 index 3798427ea..000000000 --- a/tests/datajoint/_config.py +++ /dev/null @@ -1 +0,0 @@ -DATAJOINT_SERVER_PORT = 3307 diff --git a/tests/datajoint/_datajoint_server.py b/tests/datajoint/_datajoint_server.py deleted file mode 100644 index f12455e67..000000000 --- a/tests/datajoint/_datajoint_server.py +++ /dev/null @@ -1,110 +0,0 @@ -import multiprocessing -import os -import time -import traceback - -import kachery_client as kc -from pymysql.err import OperationalError - -from ._config import DATAJOINT_SERVER_PORT - -DOCKER_IMAGE_NAME = "datajoint-server-pytest" - - -def run_service_datajoint_server(): - # The following cleanup is needed because we terminate this compute resource process - # See: https://pytest-cov.readthedocs.io/en/latest/subprocess-support.html - from pytest_cov.embed import cleanup_on_sigterm - - cleanup_on_sigterm() - - os.environ["RUNNING_PYTEST"] = "TRUE" - - ss = kc.ShellScript( - f""" - #!/bin/bash - set -ex - - docker kill {DOCKER_IMAGE_NAME} > /dev/null 2>&1 || true - docker rm {DOCKER_IMAGE_NAME} > /dev/null 2>&1 || true - exec docker run --name {DOCKER_IMAGE_NAME} -e MYSQL_ROOT_PASSWORD=tutorial -p {DATAJOINT_SERVER_PORT}:3306 datajoint/mysql - """, - redirect_output_to_stdout=True, - ) # noqa: E501 - ss.start() - ss.wait() - - -def run_datajoint_server(): - print("Starting datajoint server") - - ss_pull = kc.ShellScript( - """ - #!/bin/bash - set -ex - - exec docker pull datajoint/mysql - """ - ) - ss_pull.start() - ss_pull.wait() - - process = multiprocessing.Process( - target=run_service_datajoint_server, kwargs=dict() - ) - process.start() - - try: - _wait_for_datajoint_server_to_start() - except Exception: - kill_datajoint_server() - raise - - return process - # yield process - - # process.terminate() - # kill_datajoint_server() - - -def kill_datajoint_server(): - print("Terminating datajoint server") - - ss2 = kc.ShellScript( - f""" - #!/bin/bash - - set -ex - 
- docker kill {DOCKER_IMAGE_NAME} || true - docker rm {DOCKER_IMAGE_NAME} - """ - ) - ss2.start() - ss2.wait() - - -def _wait_for_datajoint_server_to_start(): - time.sleep(15) # it takes a while to start the server - timer = time.time() - print("Waiting for DataJoint server to start. Time", timer) - while True: - try: - from spyglass.common import Session # noqa: F401 - - return - except OperationalError as e: # e.g. Connection Error - print("DataJoint server not yet started. Time", time.time()) - print(e) - except Exception: - print("Failed to import Session. Time", time.time()) - print(traceback.format_exc()) - current_time = time.time() - elapsed = current_time - timer - if elapsed > 300: - raise Exception( - "Timeout while waiting for datajoint server to start and " - "import Session to succeed. Time", - current_time, - ) - time.sleep(5) diff --git a/tests/old_tests.py b/tests/old_tests.py new file mode 100644 index 000000000..7129e0bdf --- /dev/null +++ b/tests/old_tests.py @@ -0,0 +1,180 @@ +import datetime +import shutil +from pathlib import Path + +import pynwb +import pytest +from hdmf.backends.warnings import BrokenLinkWarning + + +@pytest.fixture(scope="session") +def new_raw_name(): + return "raw.nwb" + + +@pytest.fixture(scope="session") +def write_new_raw(new_raw_name, settings): + nwbfile = pynwb.NWBFile( + session_description="session_description", + identifier="identifier", + session_start_time=datetime.datetime.now(datetime.timezone.utc), + ) + + nwbfile.add_electrode( + id=1, + x=1.0, + y=2.0, + z=3.0, + imp=-1.0, + location="CA1", + filtering="none", + group=nwbfile.create_electrode_group( + "tetrode1", + "tetrode description", + "tetrode location", + nwbfile.create_device("dev1"), + ), + group_name="tetrode1", + ) + + nwbfile.add_acquisition( + pynwb.ecephys.ElectricalSeries( + name="test_ts", + data=[1, 2, 3], + timestamps=[1.0, 2.0, 3.0], + electrodes=nwbfile.create_electrode_table_region( + region=[0], description="electrode 1" + ), + ), + ) + + file_path = Path(settings.raw_dir) / new_raw_name + + with pynwb.NWBHDF5IO(str(file_path), mode="w") as io: + io.write(nwbfile) + + +@pytest.fixture(scope="session") +def no_ephys_name(): + return "raw_no_ephys.nwb" + + +@pytest.fixture(scope="session") +def no_ephys_path_moved(settings, no_ephys_name): + from pathlib import Path + + return Path(settings.temp_dir) / no_ephys_name + + +def test_copy_nwb( + new_raw_name, + no_ephys_name, + no_ephys_path_moved, + copy_nwb_link_raw_ephys, + settings, + write_new_raw, + minirec_content, +): + copy_nwb_link_raw_ephys(new_raw_name, no_ephys_name) + raw_path = Path(settings.raw_dir) + + # new file should not have ephys data + new_raw_abspath = raw_path / new_raw_name + no_ephys_abspath = raw_path / no_ephys_name + with pynwb.NWBHDF5IO(path=str(no_ephys_abspath), mode="r") as io: + nwb_acq = io.read().acquisition + assert nwb_acq["test_ts"].data.file.filename == str(new_raw_abspath) + + assert "test_ts" in nwb_acq, "Ephys link no longer exists" + + # test readability after moving the linking raw file (paths are stored as + # relative paths in NWB) so this should break the link (moving the + # linked-to file should also break the link) + + shutil.move(no_ephys_abspath, no_ephys_path_moved) + + with pynwb.NWBHDF5IO(path=str(no_ephys_path_moved), mode="r") as io: + with pytest.warns(BrokenLinkWarning): + nwb_acq = io.read().acquisition + assert "test_ts" not in nwb_acq, "Ephys link still exists" + + +def trim_file( + file_in="beans20190718.nwb", + file_out="beans20190718_trimmed.nwb", + 
old_spatial_series=True, +): + file_in = "beans20190718.nwb" + file_out = "beans20190718_trimmed.nwb" + + n_timestamps_to_keep = 20 # / 20000 Hz sampling rate = 1 ms + + with pynwb.NWBHDF5IO(file_in, "r", load_namespaces=True) as io: + nwbfile = io.read() + orig_eseries = nwbfile.acquisition.pop("e-series") + + # create a new ElectricalSeries with a subset of the data and timestamps + data = orig_eseries.data[0:n_timestamps_to_keep, :] + ts = orig_eseries.timestamps[0:n_timestamps_to_keep] + + electrodes = nwbfile.create_electrode_table_region( + region=orig_eseries.electrodes.data[:].tolist(), + name=orig_eseries.electrodes.name, + description=orig_eseries.electrodes.description, + ) + new_eseries = pynwb.ecephys.ElectricalSeries( + name=orig_eseries.name, + description=orig_eseries.description, + data=data, + timestamps=ts, + electrodes=electrodes, + ) + nwbfile.add_acquisition(new_eseries) + + # create a new analog TimeSeries with a subset of the data and timestamps + orig_analog = nwbfile.processing["analog"]["analog"].time_series.pop( + "analog" + ) + data = orig_analog.data[0:n_timestamps_to_keep, :] + ts = orig_analog.timestamps[0:n_timestamps_to_keep] + new_analog = pynwb.TimeSeries( + name=orig_analog.name, + description=orig_analog.description, + data=data, + timestamps=ts, + unit=orig_analog.unit, + ) + nwbfile.processing["analog"]["analog"].add_timeseries(new_analog) + + if old_spatial_series: + # remove last two columns of all SpatialSeries data (xloc2, yloc2) + # because it does not conform with NWB 2.5 and they are all zeroes + # anyway + + new_spatial_series = list() + for spatial_series_name in list( + nwbfile.processing["behavior"]["position"].spatial_series + ): + spatial_series = nwbfile.processing["behavior"][ + "position" + ].spatial_series.pop(spatial_series_name) + assert isinstance(spatial_series, pynwb.behavior.SpatialSeries) + data = spatial_series.data[:, 0:2] + ts = spatial_series.timestamps[0:n_timestamps_to_keep] + new_spatial_series.append( + pynwb.behavior.SpatialSeries( + name=spatial_series.name, + description=spatial_series.description, + data=data, + timestamps=spatial_series.timestamps, + reference_frame=spatial_series.reference_frame, + ) + ) + + for spatial_series in new_spatial_series: + nwbfile.processing["behavior"]["position"].add_spatial_series( + spatial_series + ) + + with pynwb.NWBHDF5IO(file_out, "w") as export_io: + export_io.export(io, nwbfile) diff --git a/tests/test_insert_beans.py b/tests/test_insert_beans.py deleted file mode 100644 index 29b3d7fb1..000000000 --- a/tests/test_insert_beans.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import pathlib -from datetime import datetime - -import kachery_cloud as kcl -import pynwb -import pytest - -from spyglass.common import CameraDevice, DataAcquisitionDevice, Probe, Session -from spyglass.data_import import insert_sessions - - -@pytest.mark.skip(reason="test_path needs to be updated") -def test_insert_sessions(): - print( - "In test_insert_sessions, os.environ['SPYGLASS_BASE_DIR'] is", - os.environ["SPYGLASS_BASE_DIR"], - ) - raw_dir = pathlib.Path(os.environ["SPYGLASS_BASE_DIR"]) / "raw" - nwbfile_path = raw_dir / "test.nwb" - - test_path = ( - "ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4" - ) - try: - local_test_path = kcl.load_file(test_path) - except Exception as e: - if os.environ.get("KACHERY_CLOUD_EPHEMERAL", None) != "TRUE": - print( - "Cannot load test file in non-ephemeral mode. Kachery cloud" - + "client may need to be registered." 
- ) - raise e - - # move the file to spyglass raw dir - os.rename(local_test_path, nwbfile_path) - - # test that the file can be read. this is not used otherwise - with pynwb.NWBHDF5IO( - path=str(nwbfile_path), mode="r", load_namespaces=True - ) as io: - nwbfile = io.read() - assert nwbfile is not None - - insert_sessions(nwbfile_path.name) - - x = (Session() & {"nwb_file_name": "test_.nwb"}).fetch1() - assert x["nwb_file_name"] == "test_.nwb" - assert x["subject_id"] == "Beans" - assert x["institution_name"] == "University of California, San Francisco" - assert x["lab_name"] == "Loren Frank" - assert x["session_id"] == "beans_01" - assert x["session_description"] == "Reinforcement leaarning" - assert x["session_start_time"] == datetime(2019, 7, 18, 15, 29, 47) - assert x["timestamps_reference_time"] == datetime(1970, 1, 1, 0, 0) - assert x["experiment_description"] == "Reinforcement learning" - - x = DataAcquisitionDevice().fetch() - assert len(x) == 1 - assert x[0]["device_name"] == "dataacq_device0" - assert x[0]["system"] == "SpikeGadgets" - assert x[0]["amplifier"] == "Intan" - assert x[0]["adc_circuit"] == "Intan" - - x = CameraDevice().fetch() - assert len(x) == 2 - # NOTE order of insertion is not consistent so cannot use x[0] - expected1 = dict( - camera_name="beans sleep camera", - # meters_per_pixel=0.00055, # cannot check floating point values this way - manufacturer="", - model="unknown", - lens="unknown", - camera_id=0, - ) - assert CameraDevice() & expected1 - assert (CameraDevice() & expected1).fetch("meters_per_pixel") == 0.00055 - expected2 = dict( - camera_name="beans run camera", - # meters_per_pixel=0.002, - manufacturer="", - model="unknown2", - lens="unknown2", - camera_id=1, - ) - assert CameraDevice() & expected2 - assert (CameraDevice() & expected2).fetch("meters_per_pixel") == 0.002 - - x = Probe().fetch() - assert len(x) == 1 - assert x[0]["probe_type"] == "128c-4s8mm6cm-20um-40um-sl" - assert x[0]["probe_description"] == "128 channel polyimide probe" - assert x[0]["num_shanks"] == 4 - assert x[0]["contact_side_numbering"] == "True" diff --git a/tests/trim_beans.py b/tests/trim_beans.py deleted file mode 100644 index 242e65c49..000000000 --- a/tests/trim_beans.py +++ /dev/null @@ -1,73 +0,0 @@ -import pynwb - -# import ndx_franklab_novela - -file_in = "beans20190718.nwb" -file_out = "beans20190718_trimmed.nwb" - -n_timestamps_to_keep = 20 # / 20000 Hz sampling rate = 1 ms - -with pynwb.NWBHDF5IO(file_in, "r", load_namespaces=True) as io: - nwbfile = io.read() - orig_eseries = nwbfile.acquisition.pop("e-series") - - # create a new ElectricalSeries with a subset of the data and timestamps - data = orig_eseries.data[0:n_timestamps_to_keep, :] - ts = orig_eseries.timestamps[0:n_timestamps_to_keep] - electrodes = nwbfile.create_electrode_table_region( - region=orig_eseries.electrodes.data[:].tolist(), - name=orig_eseries.electrodes.name, - description=orig_eseries.electrodes.description, - ) - new_eseries = pynwb.ecephys.ElectricalSeries( - name=orig_eseries.name, - description=orig_eseries.description, - data=data, - timestamps=ts, - electrodes=electrodes, - ) - nwbfile.add_acquisition(new_eseries) - - # create a new analog TimeSeries with a subset of the data and timestamps - orig_analog = nwbfile.processing["analog"]["analog"].time_series.pop( - "analog" - ) - data = orig_analog.data[0:n_timestamps_to_keep, :] - ts = orig_analog.timestamps[0:n_timestamps_to_keep] - new_analog = pynwb.TimeSeries( - name=orig_analog.name, - description=orig_analog.description, - 
data=data, - timestamps=ts, - unit=orig_analog.unit, - ) - nwbfile.processing["analog"]["analog"].add_timeseries(new_analog) - - # remove last two columns of all SpatialSeries data (xloc2, yloc2) because - # it does not conform with NWB 2.5 and they are all zeroes anyway - new_spatial_series = list() - for spatial_series_name in list( - nwbfile.processing["behavior"]["position"].spatial_series - ): - spatial_series = nwbfile.processing["behavior"][ - "position" - ].spatial_series.pop(spatial_series_name) - assert isinstance(spatial_series, pynwb.behavior.SpatialSeries) - data = spatial_series.data[:, 0:2] - ts = spatial_series.timestamps[0:n_timestamps_to_keep] - new_spatial_series.append( - pynwb.behavior.SpatialSeries( - name=spatial_series.name, - description=spatial_series.description, - data=data, - timestamps=spatial_series.timestamps, - reference_frame=spatial_series.reference_frame, - ) - ) - for spatial_series in new_spatial_series: - nwbfile.processing["behavior"]["position"].add_spatial_series( - spatial_series - ) - - with pynwb.NWBHDF5IO(file_out, "w") as export_io: - export_io.export(io, nwbfile) diff --git a/tests/test_nwb_helper_fn.py b/tests/utils/test_nwb_helper_fn.py similarity index 89% rename from tests/test_nwb_helper_fn.py rename to tests/utils/test_nwb_helper_fn.py index 8095f80e0..d054f7ecb 100644 --- a/tests/test_nwb_helper_fn.py +++ b/tests/utils/test_nwb_helper_fn.py @@ -3,9 +3,11 @@ import pynwb -# NOTE: importing this calls spyglass.__init__ and spyglass.common.__init__ -# which both require the DataJoint MySQL server to be up and running -from spyglass.common import get_electrode_indices + +def get_electrode_indices(*args, **kwargs): + from spyglass.common import get_electrode_indices # noqa: E402 + + return get_electrode_indices(*args, **kwargs) class TestGetElectrodeIndices(unittest.TestCase): From 0c3d0004a7d96bd61bc83fa1ac43aab15b795d3a Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 5 Jan 2024 17:43:50 -0600 Subject: [PATCH 03/16] Add tests for all tables in --- pyproject.toml | 13 +- src/spyglass/common/common_position.py | 7 +- src/spyglass/common/common_session.py | 6 +- tests/common/conftest.py | 28 ++++ tests/common/test_insert.py | 113 ++++++++++---- tests/common/test_interval_helpers.py | 1 - tests/conftest.py | 30 ++-- tests/data_import/test_insert_sessions.py | 13 +- tests/old_tests.py | 180 ---------------------- 9 files changed, 142 insertions(+), 249 deletions(-) create mode 100644 tests/common/conftest.py delete mode 100644 tests/old_tests.py diff --git a/pyproject.toml b/pyproject.toml index 30793d2af..8c303ce5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ spyglass_cli = "spyglass.cli:cli" "Homepage" = "https://github.com/LorenFrankLab/spyglass" "Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues" -[project.optional-dependencies] +[project.optional-dependencies] position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"] test = [ "docker", # for tests in a container @@ -111,20 +111,20 @@ line-length = 80 [tool.codespell] skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*' -ignore-words-list = 'nevers' +ignore-words-list = 'nevers' # Nevers - name in Citation [tool.pytest.ini_options] minversion = "7.0" addopts = [ - "-sv", - "-p no:warnings", + "-sv", + "-p no:warnings", "--no-teardown", "--quiet-spy", "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger - "--cov=spyglass", - "--cov-report=term-missing", + "--cov=spyglass", + "--cov-report=term-missing", 
"--no-cov-on-fail", ] testpaths = ["tests"] @@ -149,4 +149,3 @@ omit = [ "*/spikesorting/*", # "*/utils/*", ] - diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index d0fdc75dd..86ddd1403 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -8,7 +8,6 @@ import pynwb.behavior from position_tools import ( get_angle, - get_centroid, get_distance, get_speed, get_velocity, @@ -30,6 +29,12 @@ from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_helper_fn import deprecated_factory +try: + from position_tools import get_centroid +except ImportError: + logger.warnint("Please update position_tools to >= 0.1.0") + from position_tools import get_centroid + schema = dj.schema("common_position") diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 7903a0d7a..71a8323db 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -1,6 +1,10 @@ import datajoint as dj -from spyglass.common.common_device import CameraDevice, DataAcquisitionDevice, Probe +from spyglass.common.common_device import ( + CameraDevice, + DataAcquisitionDevice, + Probe, +) from spyglass.common.common_lab import Institution, Lab, LabMember from spyglass.common.common_nwbfile import Nwbfile from spyglass.common.common_subject import Subject diff --git a/tests/common/conftest.py b/tests/common/conftest.py new file mode 100644 index 000000000..f74e8002e --- /dev/null +++ b/tests/common/conftest.py @@ -0,0 +1,28 @@ +import pytest + + +@pytest.fixture(scope="session") +def mini_devices(mini_content): + yield mini_content.devices + + +@pytest.fixture(scope="session") +def mini_behavior(mini_content): + yield mini_content.processing.get("behavior") + + +@pytest.fixture(scope="session") +def mini_pos(mini_behavior): + yield mini_behavior.get_data_interface("position").spatial_series + + +@pytest.fixture(scope="session") +def mini_pos_series(mini_pos): + yield next(iter(mini_pos)) + + +@pytest.fixture(scope="session") +def mini_pos_tbl(common, mini_pos_series): + yield common.PositionSource.SpatialSeries * common.RawPosition.PosObject & { + "name": mini_pos_series + } diff --git a/tests/common/test_insert.py b/tests/common/test_insert.py index 8e57e22d3..5b5b2c4b1 100644 --- a/tests/common/test_insert.py +++ b/tests/common/test_insert.py @@ -1,15 +1,13 @@ +from datajoint.hash import key_hash +from pandas import DataFrame, Index from pytest import approx -def test_load_file(minirec_content): - assert minirec_content is not None +def test_insert_session(mini_insert, mini_content, mini_restr, common): + subj_raw = mini_content.subject + meta_raw = mini_content - -def test_insert_session(minirec_insert, minirec_content, minirec_restr, common): - subj_raw = minirec_content.subject - meta_raw = minirec_content - - sess_data = (common.Session & minirec_restr).fetch1() + sess_data = (common.Session & mini_restr).fetch1() assert ( sess_data["subject_id"] == subj_raw.subject_id ), "Subjuect ID not match" @@ -40,12 +38,12 @@ def test_insert_session(minirec_insert, minirec_content, minirec_restr, common): ), f"Session table {sess_attr} not match raw data {meta_attr}" -def test_insert_electrode_group(minirec_insert, minirec_content, common): +def test_insert_electrode_group(mini_insert, mini_content, common): group_name = "0" egroup_data = ( common.ElectrodeGroup & {"electrode_group_name": group_name} ).fetch1() - egroup_raw = 
minirec_content.electrode_groups.get(group_name) + egroup_raw = mini_content.electrode_groups.get(group_name) assert ( egroup_data["description"] == egroup_raw.description @@ -58,12 +56,10 @@ def test_insert_electrode_group(minirec_insert, minirec_content, common): ), "Region ID does not match across raw data and BrainRegion table" -def test_insert_electrode( - minirec_insert, minirec_content, minirec_restr, common -): +def test_insert_electrode(mini_insert, mini_content, mini_restr, common): electrode_id = "0" e_data = (common.Electrode & {"electrode_id": electrode_id}).fetch1() - e_raw = minirec_content.electrodes.get(int(electrode_id)).to_dict().copy() + e_raw = mini_content.electrodes.get(int(electrode_id)).to_dict().copy() attrs = [ ("x", "x"), @@ -75,14 +71,14 @@ def test_insert_electrode( ] for e_attr, meta_attr in attrs: - assert ( # KeyError: 0 here ↓ - e_data[e_attr] == e_raw[int(electrode_id)][meta_attr] + assert ( + e_data[e_attr] == e_raw[meta_attr][int(electrode_id)] ), f"Electrode table {e_attr} not match raw data {meta_attr}" -def test_insert_raw(minirec_insert, minirec_content, minirec_restr, common): - raw_data = (common.Raw & minirec_restr).fetch1() - raw_raw = minirec_content.get_acquisition() +def test_insert_raw(mini_insert, mini_content, mini_restr, common): + raw_data = (common.Raw & mini_restr).fetch1() + raw_raw = mini_content.get_acquisition() attrs = [ ("comments", "comments"), @@ -94,24 +90,73 @@ def test_insert_raw(minirec_insert, minirec_content, minirec_restr, common): ), f"Raw table {raw_attr} not match raw data {meta_attr}" -def test_insert_sample_count(minirec_insert, minirec_content, common): - # commont.SampleCount - assert False, "TODO" +def test_insert_sample_count(mini_insert, mini_content, mini_restr, common): + sample_data = (common.SampleCount & mini_restr).fetch1() + sample_full = mini_content.processing.get("sample_count") + if not sample_full: + assert False, "No sample count data in raw data" + sample_raw = sample_full.data_interfaces.get("sample_count") + assert ( + sample_data["sample_count_object_id"] == sample_raw.object_id + ), "SampleCount insertion error" + + +def test_insert_dio(mini_insert, mini_behavior, mini_restr, common): + events_data = (common.DIOEvents & mini_restr).fetch(as_dict=True) + events_raw = mini_behavior.get_data_interface( + "behavioral_events" + ).time_series + + assert len(events_data) == len(events_raw), "Number of events not match" + + event = "Poke1" + event_raw = events_raw.get(event) + event_data = (common.DIOEvents & {"dio_event_name": event}).fetch1() + assert ( + event_data["dio_object_id"] == event_raw.object_id + ), "DIO Event insertion error" + + +def test_insert_pos( + mini_insert, + common, + mini_behavior, + mini_restr, + mini_pos_series, + mini_pos_tbl, +): + pos_data = (common.PositionSource.SpatialSeries & mini_restr).fetch() + pos_raw = mini_behavior.get_data_interface("position").spatial_series -def test_insert_dio(minirec_insert, minirec_content, common): - # commont.DIOEvents - assert False, "TODO" + assert len(pos_data) == len(pos_raw), "Number of spatial series not match" + raw_obj_id = pos_raw[mini_pos_series].object_id + data_obj_id = mini_pos_tbl.fetch1("raw_position_object_id") + + assert data_obj_id == raw_obj_id, "PosObject insertion error" + + +def test_fetch_pos( + mini_insert, common, mini_pos, mini_pos_series, mini_pos_tbl +): + pos_key = ( + common.PositionSource.SpatialSeries & mini_pos_tbl.fetch("KEY") + ).fetch(as_dict=True)[0] + pos_df = (common.RawPosition.PosObject & 
pos_key).fetch1_dataframe() -def test_insert_pos(minirec_insert, minirec_content, common): - # commont.PositionSource * common.RawPosition - assert False, "TODO" + series = mini_pos[mini_pos_series] + raw_df = DataFrame( + data=series.data, + index=Index(series.timestamps, name="time"), + columns=[col + "1" for col in series.description.split(", ")], + ) + assert key_hash(pos_df) == key_hash(raw_df), "Spatial series fetch error" -def test_insert_device(minirec_insert, minirec_devices, common): +def test_insert_device(mini_insert, mini_devices, common): this_device = "dataacq_device0" - device_raw = minirec_devices.get(this_device) + device_raw = mini_devices.get(this_device) device_data = ( common.DataAcquisitionDevice & {"data_acquisition_device_name": this_device} @@ -130,8 +175,8 @@ def test_insert_device(minirec_insert, minirec_devices, common): ), f"Device table {device_attr} not match raw data {meta_attr}" -def test_insert_camera(minirec_insert, minirec_devices, common): - camera_raw = minirec_devices.get("camera_device 0") +def test_insert_camera(mini_insert, mini_devices, common): + camera_raw = mini_devices.get("camera_device 0") camera_data = ( common.CameraDevice & {"camera_name": camera_raw.camera_name} ).fetch1() @@ -149,9 +194,9 @@ def test_insert_camera(minirec_insert, minirec_devices, common): ), f"Camera table {camera_attr} not match raw data {meta_attr}" -def test_insert_probe(minirec_insert, minirec_devices, common): +def test_insert_probe(mini_insert, mini_devices, common): this_probe = "probe 0" - probe_raw = minirec_devices.get(this_probe) + probe_raw = mini_devices.get(this_probe) probe_id = probe_raw.probe_type probe_data = ( diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py index d91ea4a96..621210a8a 100644 --- a/tests/common/test_interval_helpers.py +++ b/tests/common/test_interval_helpers.py @@ -1,6 +1,5 @@ import numpy as np import pytest -from numpy import all, array @pytest.fixture(scope="session") diff --git a/tests/conftest.py b/tests/conftest.py index 4213b7a9b..b17dbd674 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,12 +135,12 @@ def raw_dir(base_dir): @pytest.fixture(scope="session") -def minirec_path(raw_dir): +def mini_path(raw_dir): yield raw_dir / "test.nwb" @pytest.fixture(scope="session") -def minirec_download(): +def mini_download(): # test_path = ( # "ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4" # ) @@ -158,35 +158,31 @@ def minirec_download(): @pytest.fixture(scope="session") -def minirec_content(minirec_path): +def mini_content(mini_path): with pynwb.NWBHDF5IO( - path=str(minirec_path), mode="r", load_namespaces=True + path=str(mini_path), mode="r", load_namespaces=True ) as io: nwbfile = io.read() + assert nwbfile is not None, "NWBFile empty." 
yield nwbfile @pytest.fixture(scope="session") -def minirec_open(minirec_content): - yield minirec_content +def mini_open(mini_content): + yield mini_content @pytest.fixture(scope="session") -def minirec_closed(minirec_path): +def mini_closed(mini_path): with pynwb.NWBHDF5IO( - path=str(minirec_path), mode="r", load_namespaces=True + path=str(mini_path), mode="r", load_namespaces=True ) as io: nwbfile = io.read() yield nwbfile @pytest.fixture(scope="session") -def minirec_devices(minirec_content): - yield minirec_content.devices - - -@pytest.fixture(scope="session") -def minirec_insert(minirec_path, teardown, server, dj_conn): +def mini_insert(mini_path, teardown, server, dj_conn): from spyglass.common import Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 @@ -195,7 +191,7 @@ def minirec_insert(minirec_path, teardown, server, dj_conn): Nwbfile().delete(safemode=False) if server.connected: - insert_sessions(minirec_path.name) + insert_sessions(mini_path.name) else: logger.error("No server connection.") if len(Session()) == 0: @@ -209,8 +205,8 @@ def minirec_insert(minirec_path, teardown, server, dj_conn): @pytest.fixture(scope="session") -def minirec_restr(minirec_path): - yield f"nwb_file_name LIKE '{minirec_path.stem}%'" +def mini_restr(mini_path): + yield f"nwb_file_name LIKE '{mini_path.stem}%'" @pytest.fixture(scope="session") diff --git a/tests/data_import/test_insert_sessions.py b/tests/data_import/test_insert_sessions.py index ea907fcc1..c2c8e6c23 100644 --- a/tests/data_import/test_insert_sessions.py +++ b/tests/data_import/test_insert_sessions.py @@ -1,4 +1,3 @@ -import datetime import shutil import warnings from pathlib import Path @@ -17,22 +16,20 @@ def copy_nwb_link_raw_ephys(data_import): return copy_nwb_link_raw_ephys -def test_open_path(minirec_path, minirec_open): - this_acq = minirec_open.acquisition +def test_open_path(mini_path, mini_open): + this_acq = mini_open.acquisition assert "e-series" in this_acq, "Ephys link no longer exists" assert ( - str(minirec_path) == this_acq["e-series"].data.file.filename + str(mini_path) == this_acq["e-series"].data.file.filename ), "Path of ephys link is incorrect" -def test_copy_link( - minirec_path, settings, minirec_closed, copy_nwb_link_raw_ephys -): +def test_copy_link(mini_path, settings, mini_closed, copy_nwb_link_raw_ephys): """Test readabilty after moving the linking raw file, breaking link""" new_path = Path(settings.raw_dir) / "no_ephys.nwb" new_moved = Path(settings.temp_dir) / "no_ephys_moved.nwb" - copy_nwb_link_raw_ephys(minirec_path.name, new_path.name) + copy_nwb_link_raw_ephys(mini_path.name, new_path.name) shutil.move(new_path, new_moved) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) diff --git a/tests/old_tests.py b/tests/old_tests.py deleted file mode 100644 index 7129e0bdf..000000000 --- a/tests/old_tests.py +++ /dev/null @@ -1,180 +0,0 @@ -import datetime -import shutil -from pathlib import Path - -import pynwb -import pytest -from hdmf.backends.warnings import BrokenLinkWarning - - -@pytest.fixture(scope="session") -def new_raw_name(): - return "raw.nwb" - - -@pytest.fixture(scope="session") -def write_new_raw(new_raw_name, settings): - nwbfile = pynwb.NWBFile( - session_description="session_description", - identifier="identifier", - session_start_time=datetime.datetime.now(datetime.timezone.utc), - ) - - nwbfile.add_electrode( - id=1, - x=1.0, - y=2.0, - 
z=3.0, - imp=-1.0, - location="CA1", - filtering="none", - group=nwbfile.create_electrode_group( - "tetrode1", - "tetrode description", - "tetrode location", - nwbfile.create_device("dev1"), - ), - group_name="tetrode1", - ) - - nwbfile.add_acquisition( - pynwb.ecephys.ElectricalSeries( - name="test_ts", - data=[1, 2, 3], - timestamps=[1.0, 2.0, 3.0], - electrodes=nwbfile.create_electrode_table_region( - region=[0], description="electrode 1" - ), - ), - ) - - file_path = Path(settings.raw_dir) / new_raw_name - - with pynwb.NWBHDF5IO(str(file_path), mode="w") as io: - io.write(nwbfile) - - -@pytest.fixture(scope="session") -def no_ephys_name(): - return "raw_no_ephys.nwb" - - -@pytest.fixture(scope="session") -def no_ephys_path_moved(settings, no_ephys_name): - from pathlib import Path - - return Path(settings.temp_dir) / no_ephys_name - - -def test_copy_nwb( - new_raw_name, - no_ephys_name, - no_ephys_path_moved, - copy_nwb_link_raw_ephys, - settings, - write_new_raw, - minirec_content, -): - copy_nwb_link_raw_ephys(new_raw_name, no_ephys_name) - raw_path = Path(settings.raw_dir) - - # new file should not have ephys data - new_raw_abspath = raw_path / new_raw_name - no_ephys_abspath = raw_path / no_ephys_name - with pynwb.NWBHDF5IO(path=str(no_ephys_abspath), mode="r") as io: - nwb_acq = io.read().acquisition - assert nwb_acq["test_ts"].data.file.filename == str(new_raw_abspath) - - assert "test_ts" in nwb_acq, "Ephys link no longer exists" - - # test readability after moving the linking raw file (paths are stored as - # relative paths in NWB) so this should break the link (moving the - # linked-to file should also break the link) - - shutil.move(no_ephys_abspath, no_ephys_path_moved) - - with pynwb.NWBHDF5IO(path=str(no_ephys_path_moved), mode="r") as io: - with pytest.warns(BrokenLinkWarning): - nwb_acq = io.read().acquisition - assert "test_ts" not in nwb_acq, "Ephys link still exists" - - -def trim_file( - file_in="beans20190718.nwb", - file_out="beans20190718_trimmed.nwb", - old_spatial_series=True, -): - file_in = "beans20190718.nwb" - file_out = "beans20190718_trimmed.nwb" - - n_timestamps_to_keep = 20 # / 20000 Hz sampling rate = 1 ms - - with pynwb.NWBHDF5IO(file_in, "r", load_namespaces=True) as io: - nwbfile = io.read() - orig_eseries = nwbfile.acquisition.pop("e-series") - - # create a new ElectricalSeries with a subset of the data and timestamps - data = orig_eseries.data[0:n_timestamps_to_keep, :] - ts = orig_eseries.timestamps[0:n_timestamps_to_keep] - - electrodes = nwbfile.create_electrode_table_region( - region=orig_eseries.electrodes.data[:].tolist(), - name=orig_eseries.electrodes.name, - description=orig_eseries.electrodes.description, - ) - new_eseries = pynwb.ecephys.ElectricalSeries( - name=orig_eseries.name, - description=orig_eseries.description, - data=data, - timestamps=ts, - electrodes=electrodes, - ) - nwbfile.add_acquisition(new_eseries) - - # create a new analog TimeSeries with a subset of the data and timestamps - orig_analog = nwbfile.processing["analog"]["analog"].time_series.pop( - "analog" - ) - data = orig_analog.data[0:n_timestamps_to_keep, :] - ts = orig_analog.timestamps[0:n_timestamps_to_keep] - new_analog = pynwb.TimeSeries( - name=orig_analog.name, - description=orig_analog.description, - data=data, - timestamps=ts, - unit=orig_analog.unit, - ) - nwbfile.processing["analog"]["analog"].add_timeseries(new_analog) - - if old_spatial_series: - # remove last two columns of all SpatialSeries data (xloc2, yloc2) - # because it does not conform with NWB 
2.5 and they are all zeroes - # anyway - - new_spatial_series = list() - for spatial_series_name in list( - nwbfile.processing["behavior"]["position"].spatial_series - ): - spatial_series = nwbfile.processing["behavior"][ - "position" - ].spatial_series.pop(spatial_series_name) - assert isinstance(spatial_series, pynwb.behavior.SpatialSeries) - data = spatial_series.data[:, 0:2] - ts = spatial_series.timestamps[0:n_timestamps_to_keep] - new_spatial_series.append( - pynwb.behavior.SpatialSeries( - name=spatial_series.name, - description=spatial_series.description, - data=data, - timestamps=spatial_series.timestamps, - reference_frame=spatial_series.reference_frame, - ) - ) - - for spatial_series in new_spatial_series: - nwbfile.processing["behavior"]["position"].add_spatial_series( - spatial_series - ) - - with pynwb.NWBHDF5IO(file_out, "w") as export_io: - export_io.export(io, nwbfile) From 7ff01adcfa20b1e4feb710761dada629c22e9d38 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 8 Jan 2024 17:19:12 -0600 Subject: [PATCH 04/16] WIP: Improve coverage behav, dio --- src/spyglass/common/common_dio.py | 5 +- src/spyglass/settings.py | 4 +- tests/common/conftest.py | 15 +++++ tests/common/test_behav.py | 72 +++++++++++++++++++++++ tests/common/test_device.py | 40 +++++++++++++ tests/common/test_dio.py | 31 ++++++++++ tests/common/test_insert.py | 4 +- tests/conftest.py | 35 +++++++++-- tests/data_import/test_insert_sessions.py | 2 +- 9 files changed, 196 insertions(+), 12 deletions(-) create mode 100644 tests/common/test_behav.py create mode 100644 tests/common/test_device.py create mode 100644 tests/common/test_dio.py diff --git a/src/spyglass/common/common_dio.py b/src/spyglass/common/common_dio.py index 93a087116..7eae1e9d3 100644 --- a/src/spyglass/common/common_dio.py +++ b/src/spyglass/common/common_dio.py @@ -50,7 +50,7 @@ def make(self, key): key["dio_object_id"] = event_series.object_id self.insert1(key, skip_duplicates=True) - def plot_all_dio_events(self): + def plot_all_dio_events(self, return_fig=False): """Plot all DIO events in the session. 
Examples @@ -117,3 +117,6 @@ def plot_all_dio_events(self): plt.suptitle(f"DIO events in {nwb_file_names[0]}") else: plt.suptitle(f"DIO events in {', '.join(nwb_file_names)}") + + if return_fig: + return plt.gcf() diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 122d53014..e2e0a2142 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -30,8 +30,8 @@ def __init__(self, base_dir: str = None, **kwargs): self.supplied_base_dir = base_dir self._config = dict() self.config_defaults = dict(prepopulate=True) - self._debug_mode = False - self._test_mode = False + self._debug_mode = kwargs.get("debug_mode", False) + self._test_mode = kwargs.get("test_mode", False) self._dlc_base = None self.relative_dirs = { diff --git a/tests/common/conftest.py b/tests/common/conftest.py index f74e8002e..dd9f2871f 100644 --- a/tests/common/conftest.py +++ b/tests/common/conftest.py @@ -21,8 +21,23 @@ def mini_pos_series(mini_pos): yield next(iter(mini_pos)) +@pytest.fixture(scope="session") +def mini_pos_interval_dict(common): + yield {"interval_list_name": common.PositionSource.get_pos_interval_name(0)} + + @pytest.fixture(scope="session") def mini_pos_tbl(common, mini_pos_series): yield common.PositionSource.SpatialSeries * common.RawPosition.PosObject & { "name": mini_pos_series } + + +@pytest.fixture(scope="session") +def pos_src(common): + yield common.PositionSource() + + +@pytest.fixture(scope="session") +def pos_interval_01(pos_src): + yield [pos_src.get_pos_interval_name(x) for x in range(1)] diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py new file mode 100644 index 000000000..42595e0f6 --- /dev/null +++ b/tests/common/test_behav.py @@ -0,0 +1,72 @@ +import pytest +from numpy import array_equal +from pandas import DataFrame + + +def test_invalid_interval(pos_src): + """Test invalid interval""" + with pytest.raises(ValueError): + pos_src.get_pos_interval_name("invalid_interval") + + +def test_invalid_epoch_num(common): + """Test invalid epoch num""" + with pytest.raises(ValueError): + common.PositionSource.get_epoch_num("invalid_epoch_num") + + +def test_raw_position_fetchnwb(common, mini_pos, mini_pos_interval_dict): + """Test RawPosition fetch nwb""" + fetched = DataFrame( + (common.RawPosition & mini_pos_interval_dict) + .fetch_nwb()[0]["raw_position"] + .data + ) + raw = DataFrame(mini_pos["led_0_series_0"].data) + # compare with mini_pos + assert fetched.equals(raw), "RawPosition fetch_nwb failed" + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_no_transaction(common, mini_restr): + """Test no transaction""" + common.VideoFile()._no_transaction_make(mini_restr) + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_update_entries(common): + """Test update entries""" + common.VideoFile().update_entries() + + +@pytest.mark.skip(reason="No video files in mini") +def test_videofile_getabspath(common, mini_restr): + """Test get absolute path""" + common.VideoFile().getabspath(mini_restr) + + +def test_posinterval_no_transaction(verbose_context, common, mini_restr): + """Test no transaction""" + before = common.PositionIntervalMap().fetch() + with verbose_context: + common.PositionIntervalMap()._no_transaction_make(mini_restr) + after = common.PositionIntervalMap().fetch() + assert array_equal( + before, after + ), "PositionIntervalMap no_transaction had unexpected effect" + + +def test_get_pos_interval_name(pos_src, mini_copy_name, pos_interval_01): + """Test get pos interval name""" + names = 
[f"pos {x} valid times" for x in range(1)] + assert pos_interval_01 == names, "get_pos_interval_name failed" + + +def test_convert_epoch(common, pos_interval_01): + this_key = (common.IntervalList & {"interval_list_name": "01_s1"}).fetch1() + ret = common.common_behav.convert_epoch_interval_name_to_position_interval_name( + this_key + ) + assert ( + ret == pos_interval_01[0] + ), "convert_epoch_interval_name_to_position_interval_name failed" diff --git a/tests/common/test_device.py b/tests/common/test_device.py new file mode 100644 index 000000000..05682562a --- /dev/null +++ b/tests/common/test_device.py @@ -0,0 +1,40 @@ +import pytest +from numpy import array_equal + + +def test_invalid_device(common, populate_exception): + device_dict = common.DataAcquisitionDevice.fetch(as_dict=True)[0] + device_dict["other"] = "invalid" + with pytest.raises(populate_exception): + common.DataAcquisitionDevice._add_device(device_dict) + + +def test_spikegadets_system_alias(mini_insert, common): + assert ( + common.DataAcquisitionDevice()._add_system("MCU") == "SpikeGadgets" + ), "SpikeGadgets MCU alias not found" + + +def test_invalid_probe(common, populate_exception): + probe_dict = common.ProbeType.fetch(as_dict=True)[0] + probe_dict["other"] = "invalid" + with pytest.raises(populate_exception): + common.Probe._add_probe_type(probe_dict) + + +def test_create_probe(common, mini_devices, mini_path, mini_copy_name): + probe_id = common.Probe.fetch("KEY", as_dict=True)[0] + probe_type = common.ProbeType.fetch("KEY", as_dict=True)[0] + before = common.Probe.fetch() + common.Probe.create_from_nwbfile( + nwb_file_name=mini_copy_name.split("/")[-1], + nwb_device_name="probe 0", + contact_side_numbering=False, + **probe_id, + **probe_type, + ) + after = common.Probe.fetch() + # Because already inserted, expect no change + assert array_equal( + before, after + ), "Probe create_from_nwbfile had unexpected effect" diff --git a/tests/common/test_dio.py b/tests/common/test_dio.py new file mode 100644 index 000000000..f4b258dde --- /dev/null +++ b/tests/common/test_dio.py @@ -0,0 +1,31 @@ +import pytest +from numpy import allclose, array + + +@pytest.fixture(scope="session") +def dio_events(common): + yield common.common_dio.DIOEvents + + +@pytest.fixture(scope="session") +def dio_fig(mini_insert, dio_events, mini_restr): + yield (dio_events & mini_restr).plot_all_dio_events(return_fig=True) + + +def test_plot_dio_axes(dio_fig, dio_events): + """Check that all events are plotted.""" + events_fig = set(x.yaxis.get_label().get_text() for x in dio_fig.get_axes()) + events_fetch = set(dio_events.fetch("dio_event_name")) + assert events_fig == events_fetch, "Mismatch in events plotted." + + +def test_plot_dio_data(common, dio_fig): + """Hash summary of figure object.""" + data_fig = dio_fig.get_axes()[0].lines[0].get_xdata() + data_block = ( + common.IntervalList & 'interval_list_name LIKE "raw%"' + ).fetch1("valid_times") + data_fetch = array((data_block[0][0], data_block[-1][1])) + assert allclose( + data_fig, data_fetch, atol=1e-8 + ), "Mismatch in data plotted." 
diff --git a/tests/common/test_insert.py b/tests/common/test_insert.py index 5b5b2c4b1..9d6f87ef3 100644 --- a/tests/common/test_insert.py +++ b/tests/common/test_insert.py @@ -137,13 +137,13 @@ def test_insert_pos( assert data_obj_id == raw_obj_id, "PosObject insertion error" -def test_fetch_pos( +def test_fetch_posobj( mini_insert, common, mini_pos, mini_pos_series, mini_pos_tbl ): pos_key = ( common.PositionSource.SpatialSeries & mini_pos_tbl.fetch("KEY") ).fetch(as_dict=True)[0] - pos_df = (common.RawPosition.PosObject & pos_key).fetch1_dataframe() + pos_df = (common.RawPosition & pos_key).fetch1_dataframe().iloc[:, 0:2] series = mini_pos[mini_pos_series] raw_df = DataFrame( diff --git a/tests/conftest.py b/tests/conftest.py index b17dbd674..0727a891a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import datajoint as dj import pynwb import pytest -from datajoint.logging import logger +from datajoint.logging import logger as dj_logger from .container import DockerMySQLManager @@ -139,6 +139,13 @@ def mini_path(raw_dir): yield raw_dir / "test.nwb" +@pytest.fixture(scope="session") +def mini_copy_name(mini_path): + from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename # noqa: E402 + + yield get_nwb_copy_filename(mini_path) + + @pytest.fixture(scope="session") def mini_download(): # test_path = ( @@ -181,21 +188,23 @@ def mini_closed(mini_path): yield nwbfile -@pytest.fixture(scope="session") +@pytest.fixture(autouse=True, scope="session") def mini_insert(mini_path, teardown, server, dj_conn): from spyglass.common import Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 + dj_logger.info("Inserting test data.") + if len(Nwbfile()) > 0: Nwbfile().delete(safemode=False) if server.connected: insert_sessions(mini_path.name) else: - logger.error("No server connection.") + dj_logger.error("No server connection.") if len(Session()) == 0: - logger.error("No sessions inserted.") + dj_logger.error("No sessions inserted.") yield @@ -230,18 +239,32 @@ def settings(dj_conn): yield settings +@pytest.fixture(scope="session") +def populate_exception(): + from spyglass.common.errors import PopulateException + + yield PopulateException + + # ------------------ GENERAL FUNCTION ------------------ class QuietStdOut: """If quiet_spy, used to quiet prints, teardowns and table.delete prints""" + def __init__(self): + from spyglass.utils import logger as spyglass_logger + + self.spy_logger = spyglass_logger + self.previous_level = None + def __enter__(self): - logger.setLevel("CRITICAL") + self.previous_level = self.spy_logger.getEffectiveLevel() + self.spy_logger.setLevel("CRITICAL") self._original_stdout = sys.stdout sys.stdout = open(os.devnull, "w") def __exit__(self, exc_type, exc_val, exc_tb): - logger.setLevel("INFO") + self.spy_logger.setLevel(self.previous_level) sys.stdout.close() sys.stdout = self._original_stdout diff --git a/tests/data_import/test_insert_sessions.py b/tests/data_import/test_insert_sessions.py index c2c8e6c23..7c125ed6b 100644 --- a/tests/data_import/test_insert_sessions.py +++ b/tests/data_import/test_insert_sessions.py @@ -25,7 +25,7 @@ def test_open_path(mini_path, mini_open): def test_copy_link(mini_path, settings, mini_closed, copy_nwb_link_raw_ephys): - """Test readabilty after moving the linking raw file, breaking link""" + """Test readability after moving the linking raw file, breaking link""" new_path = Path(settings.raw_dir) / 
"no_ephys.nwb" new_moved = Path(settings.temp_dir) / "no_ephys_moved.nwb" From dec365586961441d40aeb1048cb0d53ddb917820 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 9 Jan 2024 17:37:09 -0600 Subject: [PATCH 05/16] WIP: Add coverage, see details: - Add `return_fig` param to plotting helper functions to permit tests - `common_filter` - `common_interval` - Add coverage for ~1/2 of `common` - `common_behav` - `common_device` - `common_ephys` - `common_filter` - `common_interval` - with helper funcs tested seperately - `common_lab` - `common_nwbfile` - partial --- src/spyglass/common/common_filter.py | 12 +- src/spyglass/common/common_interval.py | 8 +- tests/common/conftest.py | 5 + tests/common/test_behav.py | 6 +- tests/common/test_device.py | 2 +- tests/common/test_ephys.py | 33 +++++ tests/common/test_filter.py | 79 ++++++++++ tests/common/test_interval.py | 27 ++++ tests/common/test_interval_helpers.py | 198 +++++++++++++++++++++++++ tests/common/test_lab.py | 115 ++++++++++++++ tests/common/test_nwbfile.py | 42 ++++++ tests/conftest.py | 7 +- 12 files changed, 523 insertions(+), 11 deletions(-) create mode 100644 tests/common/test_ephys.py create mode 100644 tests/common/test_filter.py create mode 100644 tests/common/test_interval.py create mode 100644 tests/common/test_lab.py create mode 100644 tests/common/test_nwbfile.py diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index 0472c6e18..9d2cdf9d6 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -167,9 +167,9 @@ def add_filter( def _filter_restrict(self, filter_name, fs): return ( self & {"filter_name": filter_name} & {"filter_sampling_rate": fs} - ).fetch1(as_dict=True) + ).fetch1() - def plot_magnitude(self, filter_name, fs): + def plot_magnitude(self, filter_name, fs, return_fig=False): filter_dict = self._filter_restrict(filter_name, fs) plt.figure() w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536) @@ -178,11 +178,13 @@ def plot_magnitude(self, filter_name, fs): plt.xlabel("Frequency (Hz)") plt.ylabel("Magnitude") plt.title("Frequency Response") - plt.xlim(0, np.max(filter_dict["filter_coeffand_edges"] * 2)) + plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2)) plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1) plt.grid(True) + if return_fig: + return plt.gcf() - def plot_fir_filter(self, filter_name, fs): + def plot_fir_filter(self, filter_name, fs, return_fig=False): filter_dict = self._filter_restrict(filter_name, fs) plt.figure() plt.clf() @@ -191,6 +193,8 @@ def plot_fir_filter(self, filter_name, fs): plt.ylabel("Magnitude") plt.title("Filter Taps") plt.grid(True) + if return_fig: + return plt.gcf() def filter_delay(self, filter_name, fs): return self.calc_filter_delay( diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 2ad12ad34..8adba824e 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -65,7 +65,7 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name): cls.insert1(epoch_dict, skip_duplicates=True) - def plot_intervals(self, figsize=(20, 5)): + def plot_intervals(self, figsize=(20, 5), return_fig=False): interval_list = pd.DataFrame(self) fig, ax = plt.subplots(figsize=figsize) interval_count = 0 @@ -83,8 +83,10 @@ def plot_intervals(self, figsize=(20, 5)): ax.set_yticklabels(interval_list.interval_list_name) ax.set_xlabel("Time [s]") ax.grid(True) + if return_fig: + return fig - def 
plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
+    def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False):
         interval_list = pd.DataFrame(self)
         fig, ax = plt.subplots(figsize=(30, 3))

@@ -144,6 +146,8 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
         ax.set_yticklabels(["pos valid times", "raw data valid times", "epoch"])
         ax.set_xlabel("Time [s]")
         ax.grid(True)
+        if return_fig:
+            return fig


 def intervals_by_length(interval_list, min_length=0.0, max_length=1e10):
diff --git a/tests/common/conftest.py b/tests/common/conftest.py
index dd9f2871f..41fdea95a 100644
--- a/tests/common/conftest.py
+++ b/tests/common/conftest.py
@@ -41,3 +41,8 @@ def pos_src(common):
 @pytest.fixture(scope="session")
 def pos_interval_01(pos_src):
     yield [pos_src.get_pos_interval_name(x) for x in range(1)]
+
+
+@pytest.fixture(scope="session")
+def common_ephys(common):
+    yield common.common_ephys
diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py
index 42595e0f6..65e4addc3 100644
--- a/tests/common/test_behav.py
+++ b/tests/common/test_behav.py
@@ -51,12 +51,12 @@ def test_posinterval_no_transaction(verbose_context, common, mini_restr):
     with verbose_context:
         common.PositionIntervalMap()._no_transaction_make(mini_restr)
     after = common.PositionIntervalMap().fetch()
-    assert array_equal(
-        before, after
+    assert (
+        len(after) == len(before) + 2
     ), "PositionIntervalMap no_transaction had unexpected effect"


-def test_get_pos_interval_name(pos_src, mini_copy_name, pos_interval_01):
+def test_get_pos_interval_name(pos_src, pos_interval_01):
     """Test get pos interval name"""
     names = [f"pos {x} valid times" for x in range(1)]
     assert pos_interval_01 == names, "get_pos_interval_name failed"
diff --git a/tests/common/test_device.py b/tests/common/test_device.py
index 05682562a..84323f2df 100644
--- a/tests/common/test_device.py
+++ b/tests/common/test_device.py
@@ -27,7 +27,7 @@ def test_create_probe(common, mini_devices, mini_path, mini_copy_name):
     probe_type = common.ProbeType.fetch("KEY", as_dict=True)[0]
     before = common.Probe.fetch()
     common.Probe.create_from_nwbfile(
-        nwb_file_name=mini_copy_name.split("/")[-1],
+        nwb_file_name=mini_copy_name,
         nwb_device_name="probe 0",
         contact_side_numbering=False,
         **probe_id,
diff --git a/tests/common/test_ephys.py b/tests/common/test_ephys.py
new file mode 100644
index 000000000..9ad1ea0a4
--- /dev/null
+++ b/tests/common/test_ephys.py
@@ -0,0 +1,33 @@
+import pytest
+from numpy import array_equal
+
+
+def test_create_from_config(mini_insert, common_ephys, mini_path):
+    before = common_ephys.Electrode().fetch()
+    common_ephys.Electrode.create_from_config(mini_path.stem)
+    after = common_ephys.Electrode().fetch()
+    # Because already inserted, expect no change
+    assert array_equal(
+        before, after
+    ), "Electrode.create_from_config had unexpected effect"
+
+
+def test_raw_object(mini_insert, common_ephys, mini_dict, mini_content):
+    obj_fetch = common_ephys.Raw().nwb_object(mini_dict).object_id
+    obj_raw = mini_content.get_acquisition().object_id
+    assert obj_fetch == obj_raw, "Raw.nwb_object did not return expected object"
+
+
+def test_set_lfp_electrodes(mini_insert, common_ephys, mini_copy_name):
+    before = common_ephys.LFPSelection().fetch()
+    common_ephys.LFPSelection().set_lfp_electrodes(mini_copy_name, [0])
+    after = common_ephys.LFPSelection().fetch()
+    # Expect one new entry for the selected electrode
+    assert (
+        len(after) == len(before) + 1
+    ), "Set LFP electrodes had unexpected effect"
+
+
+@pytest.mark.skip(reason="Not testing V0: 
common lfp")
+def test_lfp():
+    pass
diff --git a/tests/common/test_filter.py b/tests/common/test_filter.py
new file mode 100644
index 000000000..9e0be584f
--- /dev/null
+++ b/tests/common/test_filter.py
@@ -0,0 +1,79 @@
+import pytest
+
+
+@pytest.fixture(scope="session")
+def filter_parameters(common):
+    yield common.FirFilterParameters()
+
+
+@pytest.fixture(scope="session")
+def filter_dict(filter_parameters):
+    yield {"filter_name": "test", "fs": 10}
+
+
+@pytest.fixture(scope="session")
+def add_filter(filter_parameters, filter_dict):
+    filter_parameters.add_filter(
+        **filter_dict, filter_type="lowpass", band_edges=[1, 2]
+    )
+
+
+@pytest.fixture(scope="session")
+def filter_coeff(filter_parameters, filter_dict):
+    yield filter_parameters._filter_restrict(**filter_dict)["filter_coeff"]
+
+
+def test_add_filter(filter_parameters, add_filter, filter_dict):
+    """Test add filter"""
+    assert filter_parameters & filter_dict, "add_filter failed"
+
+
+def test_filter_restrict(
+    filter_parameters, add_filter, filter_dict, filter_coeff
+):
+    assert sum(filter_coeff) == pytest.approx(
+        0.999134, abs=1e-6
+    ), "filter_restrict failed"
+
+
+def test_plot_magnitude(filter_parameters, add_filter, filter_dict):
+    fig = filter_parameters.plot_magnitude(**filter_dict, return_fig=True)
+    assert sum(fig.get_axes()[0].lines[0].get_xdata()) == pytest.approx(
+        163837.5, abs=1
+    ), "plot_magnitude failed"
+
+
+def test_plot_fir_filter(
+    filter_parameters, add_filter, filter_dict, filter_coeff
+):
+    fig = filter_parameters.plot_fir_filter(**filter_dict, return_fig=True)
+    assert sum(fig.get_axes()[0].lines[0].get_ydata()) == sum(
+        filter_coeff
+    ), "Plot filter failed"
+
+
+def test_filter_delay(filter_parameters, add_filter, filter_dict):
+    delay = filter_parameters.filter_delay(**filter_dict)
+    assert delay == 27, "filter_delay failed"
+
+
+def test_time_bound_warning(filter_parameters, add_filter, filter_dict):
+    with pytest.warns(UserWarning):
+        filter_parameters._time_bound_check(1, 3, [2, 5], 4)
+
+
+@pytest.mark.skip(reason="Not testing V0: filter_data")
+def test_filter_data(filter_parameters, mini_content):
+    pass
+
+
+def test_calc_filter_delay(filter_parameters, filter_coeff):
+    delay = filter_parameters.calc_filter_delay(filter_coeff)
+    assert delay == 27, "filter_delay failed"
+
+
+def test_create_standard_filters(filter_parameters):
+    filter_parameters.create_standard_filters()
+    assert filter_parameters & {
+        "filter_name": "LFP 0-400 Hz"
+    }, "create_standard_filters failed"
diff --git a/tests/common/test_interval.py b/tests/common/test_interval.py
new file mode 100644
index 000000000..8353961f8
--- /dev/null
+++ b/tests/common/test_interval.py
@@ -0,0 +1,27 @@
+import pytest
+from numpy import array_equal
+
+
+@pytest.fixture(scope="session")
+def interval_list(common):
+    yield common.IntervalList()
+
+
+def test_plot_intervals(mini_insert, interval_list):
+    fig = interval_list.plot_intervals(return_fig=True)
+    interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
+    times_fetch = (
+        interval_list & {"interval_list_name": interval_list_name}
+    ).fetch1("valid_times")[0]
+    times_plot = fig.get_axes()[0].lines[0].get_xdata()
+
+    assert array_equal(times_fetch, times_plot), "plot_intervals failed"
+
+
+def test_plot_epoch(mini_insert, interval_list):
+    fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
+    epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
+    assert epoch_label == "epoch", "plot_epoch failed"
+
+    epoch_interv = 
fig.get_axes()[0].lines[0].get_ydata() + assert array_equal(epoch_interv, [1, 1]), "plot_epoch failed" diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py index 621210a8a..de7aad6e7 100644 --- a/tests/common/test_interval_helpers.py +++ b/tests/common/test_interval_helpers.py @@ -72,3 +72,201 @@ def test_set_difference(set_difference, one, two, expected_result): assert ( set_difference(one, two) == expected_result ), "Problem with common_interval.interval_set_difference_inds" + + +@pytest.mark.parametrize( + "expected_result, min_len, max_len", + [ + (np.array([[0, 1]]), 0.0, 10), + (np.array([[0, 1], [0, 1e11]]), 0.0, 1e12), + (np.array([[0, 0], [0, 1]]), -1, 10), + ], +) +def test_intervals_by_length(common, expected_result, min_len, max_len): + # intput is the same across all tests. Could be parametrized as above + inds = common.common_interval.intervals_by_length( + interval_list=np.array([[0, 0], [0, 1], [0, 1e11]]), + min_length=min_len, + max_length=max_len, + ) + assert np.array_equal( + inds, expected_result + ), "Problem with common_interval.intervals_by_length" + + +@pytest.fixture +def interval_list_dict(): + yield { + "interval_list": np.array([[1, 4], [6, 8]]), + "timestamps": np.array([0, 1, 5, 7, 8, 9]), + } + + +def test_interval_list_contains_ind(common, interval_list_dict): + idxs = common.common_interval.interval_list_contains_ind( + **interval_list_dict + ) + assert np.array_equal( + idxs, np.array([1, 3, 4]) + ), "Problem with common_interval.interval_list_contains_ind" + + +def test_interval_list_contains(common, interval_list_dict): + idxs = common.common_interval.interval_list_contains(**interval_list_dict) + assert np.array_equal( + idxs, np.array([1, 7, 8]) + ), "Problem with common_interval.interval_list_contains" + + +def test_interval_list_excludes_ind(common, interval_list_dict): + idxs = common.common_interval.interval_list_excludes_ind( + **interval_list_dict + ) + assert np.array_equal( + idxs, np.array([0, 2, 5]) + ), "Problem with common_interval.interval_list_excludes_ind" + + +def test_interval_list_excludes(common, interval_list_dict): + idxs = common.common_interval.interval_list_excludes(**interval_list_dict) + assert np.array_equal( + idxs, np.array([0, 5, 9]) + ), "Problem with common_interval.interval_list_excludes" + + +def test_consolidate_intervals_1dim(common): + exp = common.common_interval.consolidate_intervals(np.array([0, 1])) + assert np.array_equal( + exp, np.array([[0, 1]]) + ), "Problem with common_interval.consolidate_intervals" + + +@pytest.mark.parametrize( + "interval1, interval2, exp_result", + [ + ( + np.array([[0, 1]]), + np.array([[2, 3]]), + np.array([[0, 3]]), + ), + ( + np.array([[2, 3]]), + np.array([[0, 1]]), + np.array([[0, 3]]), + ), + ( + np.array([[0, 3]]), + np.array([[2, 4]]), + np.array([[0, 3], [2, 4]]), + ), + ], +) +def test_union_adjacent_index(common, interval1, interval2, exp_result): + assert np.array_equal( + common.common_interval.union_adjacent_index(interval1, interval2), + exp_result, + ), "Problem with common_interval.union_adjacent_index" + + +@pytest.mark.parametrize( + "interval1, interval2, exp_result", + [ + ( + np.array([[0, 3]]), + np.array([[2, 4]]), + np.array([[0, 4]]), + ), + ( + np.array([[0, -1]]), + np.array([[2, 4]]), + np.array([[2, 0]]), + ), + ( + np.array([[0, 1]]), + np.array([[2, 1e11]]), + np.array([[0, 1], [2, 1e11]]), + ), + ], +) +def test_interval_list_union(common, interval1, interval2, exp_result): + assert np.array_equal( + 
common.common_interval.interval_list_union(interval1, interval2), + exp_result, + ), "Problem with common_interval.interval_list_union" + + +def test_interval_list_censor_error(common): + with pytest.raises(ValueError): + common.common_interval.interval_list_censor( + np.array([[0, 1]]), np.array([2]) + ) + + +def test_interval_list_censor(common): + assert np.array_equal( + common.common_interval.interval_list_censor( + np.array([[0, 2], [4, 5]]), np.array([1, 2, 4]) + ), + np.array([[1, 2]]), + ), "Problem with common_interval.interval_list_censor" + + +@pytest.mark.parametrize( + "interval_list, exp_result", + [ + ( + np.array([0, 1, 2, 3, 6, 7, 8, 9]), + np.array([[0, 3], [6, 9]]), + ), + ( + np.array([0, 1, 2]), + np.array([[0, 2]]), + ), + ( + np.array([2, 3, 1, 0]), + np.array([[0, 3]]), + ), + ( + np.array([2, 3, 0]), + np.array([[0, 0], [2, 3]]), + ), + ], +) +def test_interval_from_inds(common, interval_list, exp_result): + assert np.array_equal( + common.common_interval.interval_from_inds(interval_list), + exp_result, + ), "Problem with common_interval.interval_from_inds" + + +@pytest.mark.parametrize( + "intervals1, intervals2, min_length, exp_result", + [ + ( + np.array([[0, 2], [4, 5]]), + np.array([[1, 3], [2, 4]]), + 0, + np.array([[0, 1], [4, 5]]), + ), + ( + np.array([[0, 2], [4, 5]]), + np.array([[1, 3], [2, 4]]), + 1, + np.zeros((0, 2)), + ), + ( + np.array([[0, 2], [4, 6]]), + np.array([[5, 8], [2, 4]]), + 1, + np.array([[0, 2]]), + ), + ], +) +def test_interval_list_complement( + common, intervals1, intervals2, min_length, exp_result +): + ic = common.common_interval.interval_list_complement + assert np.array_equal( + ic(intervals1, intervals2, min_length), + exp_result, + ), "Problem with common_interval.interval_list_complement" diff --git a/tests/common/test_lab.py b/tests/common/test_lab.py new file mode 100644 index 000000000..7c74ecd1c --- /dev/null +++ b/tests/common/test_lab.py @@ -0,0 +1,115 @@ +import pytest +from numpy import array_equal + + +@pytest.fixture +def common_lab(common): + yield common.common_lab + + +@pytest.fixture +def add_admin(common_lab, teardown): + common_lab.LabMember.insert1( + dict( + lab_member_name="This Admin", + first_name="This", + last_name="Admin", + ), + skip_duplicates=True, + ) + common_lab.LabMember.LabMemberInfo.insert1( + dict( + lab_member_name="This Admin", + google_user_name="This Admin", + datajoint_user_name="this_admin", + admin=1, + ), + skip_duplicates=True, + ) + yield + if teardown: + common_lab.LabMember.delete(safe_mode=False) + + +@pytest.fixture +def add_member_team(common_lab, add_admin, teardown): + common_lab.LabMember.insert( + [ + dict( + lab_member_name="This Basic", + first_name="This", + last_name="Basic", + ), + dict( + lab_member_name="This Loner", + first_name="This", + last_name="Loner", + ), + ], + skip_duplicates=True, + ) + common_lab.LabMember.LabMemberInfo.insert( + [ + dict( + lab_member_name="This Basic", + google_user_name="This Basic", + datajoint_user_name="this_basic", + admin=0, + ), + dict( + lab_member_name="This Loner", + google_user_name="This Loner", + datajoint_user_name="this_loner", + admin=0, + ), + ], + skip_duplicates=True, + ) + common_lab.LabTeam.create_new_team( + team_name="This Team", + team_members=["This Admin", "This Basic"], + team_description="This Team Description", + ) + yield + if teardown: + common_lab.LabMember.delete(safe_mode=False) + common_lab.LabTeam.delete(safe_mode=False) + + +def test_labmember_insert_file_str(mini_insert, common_lab, mini_copy_name): 
+ before = common_lab.LabMember.fetch() + common_lab.LabMember.insert_from_nwbfile(mini_copy_name) + after = common_lab.LabMember.fetch() + # Already inserted, test func raises no error + assert array_equal(before, after), "LabMember not inserted correctly" + + +def test_fetch_admin(common_lab, add_admin): + assert ( + "this_admin" in common_lab.LabMember().admin + ), "LabMember admin not fetched correctly" + + +def test_get_djuser(common_lab, add_admin): + assert "This Admin" == common_lab.LabMember().get_djuser_name( + "this_admin" + ), "LabMember get_djuser not fetched correctly" + + +def test_get_djuser_error(common_lab, add_admin): + with pytest.raises(ValueError): + common_lab.LabMember().get_djuser_name("This Admin2") + + +def test_get_team_members(common_lab, add_member_team): + assert common_lab.LabTeam().get_team_members("This Admin") == set( + ("This Admin", "This Basic") + ), "LabTeam get_team_members not fetched correctly" + + +def test_decompose_name_error(common_lab): + # NOTE: Should change with solve of #304 + with pytest.raises(ValueError): + common_lab.decompose_name("This Invalid Name") + with pytest.raises(ValueError): + common_lab.decompose_name("This, Invalid, Name") diff --git a/tests/common/test_nwbfile.py b/tests/common/test_nwbfile.py new file mode 100644 index 000000000..546d2f05f --- /dev/null +++ b/tests/common/test_nwbfile.py @@ -0,0 +1,42 @@ +import os + +import pytest + + +@pytest.fixture +def common_nwbfile(common): + """Return a common NWBFile object.""" + return common.common_nwbfile + + +@pytest.fixture +def lockfile(base_dir, teardown): + lockfile = base_dir / "temp.lock" + lockfile.touch() + os.environ["NWB_LOCK_FILE"] = str(lockfile) + yield lockfile + if teardown: + os.remove(lockfile) + lockfile.unlink() + + +def test_get_file_name_error(common_nwbfile): + """Test that an error is raised when trying non-existent file.""" + with pytest.raises(ValueError): + common_nwbfile.NWBFile.get_file_name("non-existent-file.nwb") + + +def test_add_to_lock(common_nwbfile, lockfile, mini_copy_name): + common_nwbfile.NWBFile.add_to_lock(mini_copy_name) + with lockfile.open("r") as f: + assert mini_copy_name in f.read() + + with pytest.raises(AssertionError): + common_nwbfile.NWBFile.add_to_lock(mini_copy_name) + + +def test_nwbfile_cleanup(common_nwbfile): + before = len(common_nwbfile.NWBFile.fetch()) + common_nwbfile.NWBFile.add_to_lock(delete_files=False) + after = len(common_nwbfile.NWBFile.fetch()) + assert before == after, "Nwbfile cleanup changed table entry count." 
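The lockfile fixture above follows the conditional-teardown pattern used throughout this suite: create a resource, `yield` it to the test, and clean up only when the session-scoped `teardown` flag (disabled via `--no-teardown`) allows it, so outputs stay inspectable after a run. A minimal sketch of the pattern, assuming the suite's `base_dir` and `teardown` fixtures; the `scratch_file` resource itself is a hypothetical stand-in:

```python
import pytest


@pytest.fixture
def scratch_file(base_dir, teardown):
    # Setup: create the resource under the test data directory.
    path = base_dir / "scratch.txt"  # hypothetical resource, not in the suite
    path.touch()

    yield path  # hand the resource to the test

    # Cleanup runs only when teardown is enabled; with --no-teardown,
    # the file is left in place for post-run inspection.
    if teardown:
        path.unlink()
```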
diff --git a/tests/conftest.py b/tests/conftest.py index 0727a891a..877408256 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,7 +143,7 @@ def mini_path(raw_dir): def mini_copy_name(mini_path): from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename # noqa: E402 - yield get_nwb_copy_filename(mini_path) + yield get_nwb_copy_filename(mini_path).split("/")[-1] @pytest.fixture(scope="session") @@ -218,6 +218,11 @@ def mini_restr(mini_path): yield f"nwb_file_name LIKE '{mini_path.stem}%'" +@pytest.fixture(scope="session") +def mini_dict(mini_copy_name): + yield {"nwb_file_name": mini_copy_name} + + @pytest.fixture(scope="session") def common(dj_conn): from spyglass import common From 468ae2d43bd985446150edced9227861bc8c78b9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 10 Jan 2024 17:18:27 -0600 Subject: [PATCH 06/16] WIP pytest common 2nd half, start lfp --- .github/workflows/test-conda.yml | 10 ++ src/spyglass/common/common_session.py | 10 +- tests/common/test_position.py | 151 ++++++++++++++++++++++++ tests/common/test_region.py | 21 ++++ tests/common/test_ripple.py | 6 + tests/common/test_sensors.py | 21 ++++ tests/common/test_session.py | 81 +++++++++++++ tests/conftest.py | 29 ++--- tests/lfp/conftest.py | 158 ++++++++++++++++++++++++++ tests/lfp/test_pipeline.py | 28 +++++ 10 files changed, 492 insertions(+), 23 deletions(-) create mode 100644 tests/common/test_position.py create mode 100644 tests/common/test_region.py create mode 100644 tests/common/test_ripple.py create mode 100644 tests/common/test_sensors.py create mode 100644 tests/common/test_session.py create mode 100644 tests/lfp/conftest.py create mode 100644 tests/lfp/test_pipeline.py diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml index 576713163..d30475527 100644 --- a/.github/workflows/test-conda.yml +++ b/.github/workflows/test-conda.yml @@ -39,6 +39,16 @@ jobs: - name: Install spyglass run: | pip install -e .[test] + - name: Download data + env: + UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }} + UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} + WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb + run: | + mkdir -p ./tests/test_data; mkdir -p ./tests/test_data/raw + wget --recursive --no-verbose --no-host-directories --no-directories \ + --user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \ + -P ./tests/test_data/raw $WEBSITE - name: Run tests run: | pytest -rP # env vars are set within certain tests diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 71a8323db..f6f783262 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -223,17 +223,19 @@ def add_session_to_group( ) @staticmethod - def remove_session_from_group(nwb_file_name: str, session_group_name: str): + def remove_session_from_group( + nwb_file_name: str, session_group_name: str, *args, **kwargs + ): query = { "session_group_name": session_group_name, "nwb_file_name": nwb_file_name, } - (SessionGroupSession & query).delete() + (SessionGroupSession & query).delete(*args, **kwargs) @staticmethod - def delete_group(session_group_name: str): + def delete_group(session_group_name: str, *args, **kwargs): query = {"session_group_name": session_group_name} - (SessionGroup & query).delete() + (SessionGroup & query).delete(*args, **kwargs) @staticmethod def get_group_sessions(session_group_name: str): diff --git a/tests/common/test_position.py b/tests/common/test_position.py new file mode 100644 index 
000000000..47f285977 --- /dev/null +++ b/tests/common/test_position.py @@ -0,0 +1,151 @@ +import pytest +from datajoint.hash import key_hash + + +@pytest.fixture +def common_position(common): + yield common.common_position + + +@pytest.fixture +def interval_position_info(common_position): + yield common_position.IntervalPositionInfo + + +@pytest.fixture +def default_param_key(): + yield {"position_info_param_name": "default"} + + +@pytest.fixture +def interval_key(common): + yield (common.IntervalList & "interval_list_name LIKE 'pos 0%'").fetch1( + "KEY" + ) + + +@pytest.fixture +def param_table(common_position, default_param_key, teardown): + param_table = common_position.PositionInfoParameters() + param_table.insert1(default_param_key, skip_duplicates=True) + yield param_table + if teardown: + param_table.delete(safemode=False) + + +@pytest.fixture +def upsample_position( + common, + common_position, + param_table, + default_param_key, + teardown, + interval_key, +): + params = (param_table & default_param_key).fetch1() + upsample_param_key = {"position_info_param_name": "upsampled"} + param_table.insert1( + { + **params, + **upsample_param_key, + "is_upsampled": 1, + "max_separation": 80, + "upsampling_sampling_rate": 500, + }, + skip_duplicates=True, + ) + interval_pos_key = {**interval_key, **upsample_param_key} + common_position.IntervalPositionInfoSelection.insert1( + interval_pos_key, skip_duplicates=True + ) + common_position.IntervalPositionInfo.populate(interval_pos_key) + yield interval_pos_key + if teardown: + (param_table & upsample_param_key).delete(safemode=False) + + +@pytest.fixture +def interval_pos_key(upsample_position): + yield upsample_position + + +def test_interval_position_info_insert(common_position, interval_pos_key): + assert common_position.IntervalPositionInfo & interval_pos_key + + +@pytest.fixture +def upsample_position_error( + upsample_position, + default_param_key, + param_table, + common, + common_position, + teardown, + interval_key, +): + params = (param_table & default_param_key).fetch1() + upsample_param_key = {"position_info_param_name": "upsampled error"} + param_table.insert1( + { + **params, + **upsample_param_key, + "is_upsampled": 1, + "max_separation": 1, + "upsampling_sampling_rate": 500, + }, + skip_duplicates=True, + ) + interval_pos_key = {**interval_key, **upsample_param_key} + common_position.IntervalPositionInfoSelection.insert1(interval_pos_key) + yield interval_pos_key + if teardown: + (param_table & upsample_param_key).delete(safemode=False) + + +def test_interval_position_info_insert_error( + interval_position_info, upsample_position_error +): + with pytest.raises(ValueError): + interval_position_info.populate(upsample_position_error) + + +def test_fetch1_dataframe(interval_position_info, interval_pos_key): + df = (interval_position_info & interval_pos_key).fetch1_dataframe() + err_msg = "Unexpected output of IntervalPositionInfo.fetch1_dataframe" + assert df.shape == (5193, 6), err_msg + + df_sums = {c: df[c].iloc[:5].sum() for c in df.columns} + df_sums_exp = { + "head_orientation": 4.4300073600180125, + "head_position_x": 111.25, + "head_position_y": 141.75, + "head_speed": 0.6084872579024899, + "head_velocity_x": -0.4329520555149495, + "head_velocity_y": 0.42756198762527325, + } + for k in df_sums: + assert k in df_sums_exp, err_msg + assert df_sums[k] == pytest.approx(df_sums_exp[k], rel=0.02), err_msg + + +def test_interval_position_info_kwarg_error(interval_position_info): + with pytest.raises(ValueError): + 
interval_position_info._fix_kwargs() + + +def test_interval_position_info_kwarg_alias(interval_position_info): + in_tuple = (0, 1, 2, 3) + out_tuple = interval_position_info._fix_kwargs( + head_orient_smoothing_std_dev=in_tuple[0], + head_speed_smoothing_std_dev=in_tuple[1], + max_separation=in_tuple[2], + max_speed=in_tuple[3], + ) + assert ( + out_tuple == in_tuple + ), "IntervalPositionInfo._fix_kwargs() should alias old arg names." + + +@pytest.mark.skip(reason="Not testing with video data yet.") +def test_position_video(common_position): + pass diff --git a/tests/common/test_region.py b/tests/common/test_region.py new file mode 100644 index 000000000..1c07689ec --- /dev/null +++ b/tests/common/test_region.py @@ -0,0 +1,21 @@ +import pytest +from datajoint import U as dj_U + + +@pytest.fixture +def brain_region(common): + yield common.common_region.BrainRegion() + + +def test_region_add(brain_region): + next_id = ( + dj_U().aggr(brain_region, n="max(region_id)").fetch1("n") or 0 + ) + 1 + region_id = brain_region.fetch_add( + region_name="test_region_add", + subregion_name="test_subregion_add", + subsubregion_name="test_subsubregion_add", + ) + assert ( + region_id == next_id + ), "Region.fetch_add() should auto-increment region_id." diff --git a/tests/common/test_ripple.py b/tests/common/test_ripple.py new file mode 100644 index 000000000..71a57d022 --- /dev/null +++ b/tests/common/test_ripple.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.mark.skip(reason="Not testing V0: common_ripple") +def test_common_ripple(common): + pass diff --git a/tests/common/test_sensors.py b/tests/common/test_sensors.py new file mode 100644 index 000000000..9cdedeeb4 --- /dev/null +++ b/tests/common/test_sensors.py @@ -0,0 +1,21 @@ +import pytest + + +@pytest.fixture +def sensor_data(common, mini_insert): + tbl = common.common_sensors.SensorData() + tbl.populate() + yield tbl + + +def test_sensor_data_insert(sensor_data, mini_insert, mini_restr, mini_content): + obj_fetch = (sensor_data & mini_restr).fetch1("sensor_data_object_id") + obj_raw = ( + mini_content.processing["analog"] + .data_interfaces["analog"] + .time_series["analog"] + .object_id + ) + assert ( + obj_fetch == obj_raw + ), "SensorData object_id does not match raw object_id." 
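Several of these tests share one round-trip strategy: fetch the `object_id` a table stored during populate, then compare it to the id of the same object read straight from the NWB file. A minimal sketch of the file-side half of that check, assuming a hypothetical local `minirec.nwb` copy containing the nested `analog` TimeSeries used in `test_sensor_data_insert` above:

```python
from pynwb import NWBHDF5IO

# Open the NWB file directly, bypassing the database, to get the
# ground-truth object_id of the nested analog TimeSeries.
with NWBHDF5IO("minirec.nwb", mode="r", load_namespaces=True) as io:
    nwbfile = io.read()
    raw_id = (
        nwbfile.processing["analog"]
        .data_interfaces["analog"]
        .time_series["analog"]
        .object_id
    )

# A test would then assert the table stored the same id, e.g.:
# assert fetched_id == raw_id  # fetched_id: hypothetical table fetch result
```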
diff --git a/tests/common/test_session.py b/tests/common/test_session.py new file mode 100644 index 000000000..6e0a8f0ce --- /dev/null +++ b/tests/common/test_session.py @@ -0,0 +1,81 @@ +import pytest +from datajoint.errors import DataJointError + + +@pytest.fixture +def common_session(common): + return common.common_session + + +@pytest.fixture +def group_name_dict(): + return {"session_group_name": "group1"} + + +@pytest.fixture +def add_session_group(common_session, group_name_dict): + session_group = common_session.SessionGroup() + session_group_dict = { + **group_name_dict, + "session_group_description": "group1 description", + } + session_group.add_group(**session_group_dict, skip_duplicates=True) + session_group_dict["session_group_description"] = "updated description" + session_group.update_session_group_description(**session_group_dict) + yield session_group, session_group_dict + + +@pytest.fixture +def session_group(add_session_group): + yield add_session_group[0] + + +@pytest.fixture +def session_group_dict(add_session_group): + yield add_session_group[1] + + +def test_session_group_add(session_group, session_group_dict): + assert session_group & session_group_dict, "Session group not added" + + +@pytest.fixture +def add_session_to_group(session_group, mini_copy_name, group_name_dict): + session_group.add_session_to_group( + nwb_file_name=mini_copy_name, **group_name_dict + ) + + +def test_addremove_session_group( + common_session, + session_group, + session_group_dict, + group_name_dict, + mini_copy_name, + add_session_to_group, + add_session_group, +): + assert session_group & session_group_dict, "Session not added to group" + + session_group.remove_session_from_group( + nwb_file_name=mini_copy_name, + safemode=False, + **group_name_dict, + ) + assert ( + len(common_session.SessionGroupSession & session_group_dict) == 0 + ), "SessionGroupSession not removed by helper function" + + +def test_get_group_sessions( + session_group, group_name_dict, add_session_to_group +): + ret = session_group.get_group_sessions(**group_name_dict) + assert len(ret) == 1, "Incorrect number of sessions returned" + + +def test_delete_group_error(session_group, group_name_dict): + session_group.delete_group(**group_name_dict, safemode=False) + assert ( + len(session_group & group_name_dict) == 0 + ), "Group not deleted by helper function" diff --git a/tests/conftest.py b/tests/conftest.py index 877408256..f48bc15be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from time import sleep as tsleep import datajoint as dj import pynwb @@ -14,6 +15,7 @@ # ---------------------- CONSTANTS --------------------- # globals in pytest_configure: BASE_DIR, SERVER, TEARDOWN, VERBOSE +# download managed by gh-action test-conda, so no need to download here warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") @@ -136,7 +138,14 @@ def raw_dir(base_dir): @pytest.fixture(scope="session") def mini_path(raw_dir): - yield raw_dir / "test.nwb" + path = raw_dir / "minirec20230622.nwb" + + timeout, wait = 60, 5 # download managed by gh-action test-conda + for _ in range(timeout // wait): # wait for download to finish + if path.exists(): + break + tsleep(wait) + yield path @pytest.fixture(scope="session") @@ -146,24 +155,6 @@ def mini_copy_name(mini_path): yield get_nwb_copy_filename(mini_path).split("/")[-1] -@pytest.fixture(scope="session") -def mini_download(): - # test_path = ( - # 
"ipfs://bafybeie4svt3paz5vr7cw7mkgibutbtbzyab4s24hqn5pzim3sgg56m3n4" - # ) - # try: - # local_test_path = kcl.load_file(test_path) - # except Exception as e: - # if os.environ.get("KACHERY_CLOUD_EPHEMERAL", None) != "TRUE": - # print( - # "Cannot load test file in non-ephemeral mode. Kachery cloud" - # + "client may need to be registered." - # ) - # raise e - # os.rename(local_test_path, nwbfile_path) - pass - - @pytest.fixture(scope="session") def mini_content(mini_path): with pynwb.NWBHDF5IO( diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py new file mode 100644 index 000000000..de399b0e6 --- /dev/null +++ b/tests/lfp/conftest.py @@ -0,0 +1,158 @@ +import copy + +import numpy as np +import pytest + + +@pytest.fixture(scope="session") +def lfp(common): + from spyglass import lfp + + return lfp + + +@pytest.fixture(scope="session") +def lfp_band(lfp): + return lfp.analysis.v1 + + +@pytest.fixture(scope="session") +def lfp_constants(common, mini_copy_name): + n_delay = 9 + lfp_electrode_group_name = "test" + orig_interval_list_name, orig_valid_times = ( + common.IntervalList & "interval_list_name LIKE '01_%'" + ).fetch("interval_list_name", "valid_times")[0] + new_interval_list_name = orig_interval_list_name + f"_first{n_delay}" + new_interval_list_key = { + "nwb_file_name": mini_copy_name, + "interval_list_name": new_interval_list_name, + "valid_times": np.asarray( + [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] + ), + } + + yield dict( + lfp_electrode_ids=[0], + lfp_electrode_group_name=lfp_electrode_group_name, + lfp_eg_key={ + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + }, + n_delay=n_delay, + orig_interval_list_name="01_s1", + orig_valid_times=orig_valid_times, + interval_list_name=new_interval_list_name, + interval_key=new_interval_list_key, + filter1_name="LFP 0-400 Hz", + filter_sampling_rate=30_000, + filter2_name="Theta 5-11 Hz", + lfp_band_electrode_ids=[0], # assumes we've filtered these electrodes + lfp_band_sampling_rate=100, # desired sampling rate + ) + + +@pytest.fixture(scope="session") +def add_electrode_group(common, mini_copy_name, lfp_constants): + common.FirFilterParameters().create_standard_filters() + lfp.lfp_electrode.LFPElectrodeGroup.create_lfp_electrode_group( + nwb_file_name=mini_copy_name, + group_name=lfp_constants.get("lfp_electrode_group_name"), + electrode_list=lfp_constants.get("lfp_electrode_ids"), + ) + + +@pytest.fixture(scope="session") +def add_interval(common, lfp_constants): + common.IntervalList.insert1( + lfp_constants.get("interval_key"), skip_duplicates=True + ) + yield lfp_constants.get("interval_list_name") + + +@pytest.fixture(scope="session") +def add_selection(lfp, common, add_interval, lfp_constants): + lfp_s_key = { + **lfp_constants.get("lfp_eg_key"), + "target_interval_list_name": add_interval, + "filter_name": lfp_constants.get("filter1_name"), + "filter_sampling_rate": lfp_constants.get("filter_sampling_rate"), + } + lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) + yield lfp_s_key + + +@pytest.fixture(scope="session") +def lfp_s_key(add_selection): + yield add_selection + + +@pytest.fixture(scope="session") +def populate_lfp(lfp, add_selection): + lfp.v1.LFPV1().populate(add_selection) + yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} + + +@pytest.fixture(scope="session") +def lfp_merge_key(populate_lfp): + yield populate_lfp + + +@pytest.fixture(scope="session") +def lfp_band_sampling_rate(lfp, lfp_merge_key): + yield 
lfp.LFPOutput.merge_get_parent(lfp_merge_key).fetch1( + "lfp_sampling_rate" + ) + + +@pytest.fixture(scope="session") +def add_band_filter(common, lfp_constants, lfp_band_sampling_rate): + common.FirFilterParameters().add_filter( + lfp_constants.get("filter2_name"), + lfp_band_sampling_rate, + "bandpass", + [4, 5, 11, 12], + "theta filter for 1 Khz data", + ) + yield lfp_constants.get("filter2_name") + + +@pytest.fixture(scope="session") +def add_band_selection( + lfp_band, + mini_copy_name, + lfp_merge_key, + add_interval, + lfp_constants, + add_band_filter, +): + lfp_band.LFPBandSelection().set_lfp_band_electrodes( + nwb_file_name=mini_copy_name, + lfp_merge_id=lfp_merge_key.get("merge_id"), + electrode_list=lfp_constants.get("lfp_band_electrode_ids"), + filter_name=add_band_filter, + interval_list_name=add_interval, + reference_electrode_list=[-1], + lfp_band_sampling_rate=lfp_constants.get("lfp_band_sampling_rate"), + ) + yield (lfp_band.LFPBandSelection().fetch1("KEY") & lfp_merge_key).fetch1( + "KEY" + ) + + +@pytest.fixture(scope="session") +def lfp_band_key(add_band_selection): + yield add_band_selection + + +@pytest.fixture(scope="session") +def populate_lfp_band(lfp_band, add_band_selection): + lfp_band.LFPBandV1().populate(add_band_selection) + yield + + +@pytest.fixture(scope="session") +def mini_eseries(common, mini_copy_name): + yield (common.Raw() & {"nwb_file_name": mini_copy_name}).fetch_nwb()[0][ + "raw" + ] diff --git a/tests/lfp/test_pipeline.py b/tests/lfp/test_pipeline.py new file mode 100644 index 000000000..a6567f548 --- /dev/null +++ b/tests/lfp/test_pipeline.py @@ -0,0 +1,28 @@ +import numpy as np + + +def test_lfp_eseries(common, lfp, mini_eseries, lfp_constants, lfp_merge_key): + lfp_elect_ids = lfp_constants.get("lfp_band_electrode_ids") + + lfp_elect_indices = common.get_electrode_indices( + mini_eseries, lfp_elect_ids + ) + lfp_timestamps = np.asarray(mini_eseries.timestamps) + lfp_eseries = lfp.LFPOutput.fetch_nwb(lfp_merge_key)[0]["lfp"] + assert False + + +def test_lfp_band_eseries(ccf, lfp_band, lfp_band_key, lfp_constants): + lfp_band_elect_ids = lfp_constants.get("lfp_band_electrode_ids") + lfp_elect_indices = common.get_electrode_indices( + lfp_eseries, lfp_band_electrode_ids + ) + lfp_timestamps = np.asarray(lfp_eseries.timestamps) + lfp_band_eseries = (lfp_band.LFPBandV1 & lfp_band_key).fetch_nwb()[0][ + "lfp_band" + ] + lfp_band_elect_indices = common.get_electrode_indices( + lfp_band_eseries, lfp_band_electrode_ids + ) + lfp_band_timestamps = np.asarray(lfp_band_eseries.timestamps) + assert False From 61d832591866d29f3914f5f10ece3fcc5b9eb2c5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 12 Jan 2024 14:34:55 -0600 Subject: [PATCH 07/16] WIP lfp tests, ahead of fetch upstream --- tests/common/test_nwbfile.py | 12 +++--- tests/common/test_region.py | 16 ++++++-- tests/conftest.py | 2 +- tests/lfp/conftest.py | 77 ++++++++++++++++++++++++++---------- 4 files changed, 76 insertions(+), 31 deletions(-) diff --git a/tests/common/test_nwbfile.py b/tests/common/test_nwbfile.py index 546d2f05f..518955274 100644 --- a/tests/common/test_nwbfile.py +++ b/tests/common/test_nwbfile.py @@ -23,20 +23,20 @@ def lockfile(base_dir, teardown): def test_get_file_name_error(common_nwbfile): """Test that an error is raised when trying non-existent file.""" with pytest.raises(ValueError): - common_nwbfile.NWBFile.get_file_name("non-existent-file.nwb") + common_nwbfile.Nwbfile._get_file_name("non-existent-file.nwb") def test_add_to_lock(common_nwbfile, lockfile, 
mini_copy_name): - common_nwbfile.NWBFile.add_to_lock(mini_copy_name) + common_nwbfile.Nwbfile.add_to_lock(mini_copy_name) with lockfile.open("r") as f: assert mini_copy_name in f.read() with pytest.raises(AssertionError): - common_nwbfile.NWBFile.add_to_lock(mini_copy_name) + common_nwbfile.Nwbfile.add_to_lock("non-existent-file.nwb") def test_nwbfile_cleanup(common_nwbfile): - before = len(common_nwbfile.NWBFile.fetch()) - common_nwbfile.NWBFile.add_to_lock(delete_files=False) - after = len(common_nwbfile.NWBFile.fetch()) + before = len(common_nwbfile.Nwbfile.fetch()) + common_nwbfile.Nwbfile.cleanup(delete_files=False) + after = len(common_nwbfile.Nwbfile.fetch()) assert before == after, "Nwbfile cleanup changed table entry count." diff --git a/tests/common/test_region.py b/tests/common/test_region.py index 1c07689ec..95f62fe1b 100644 --- a/tests/common/test_region.py +++ b/tests/common/test_region.py @@ -3,16 +3,24 @@ @pytest.fixture -def brain_region(common): - yield common.common_region.BrainRegion() +def region_dict(): + yield dict(region_name="test_region") -def test_region_add(brain_region): +@pytest.fixture +def brain_region(common, region_dict): + brain_region = common.common_region.BrainRegion() + (brain_region & "region_id > 1").delete(safemode=False) + yield brain_region + (brain_region & "region_id > 1").delete(safemode=False) + + +def test_region_add(brain_region, region_dict): next_id = ( dj_U().aggr(brain_region, n="max(region_id)").fetch1("n") or 0 ) + 1 region_id = brain_region.fetch_add( - region_name="test_region_add", + **region_dict, subregion_name="test_subregion_add", subsubregion_name="test_subsubregion_add", ) diff --git a/tests/conftest.py b/tests/conftest.py index f48bc15be..0a328c767 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,7 +188,7 @@ def mini_insert(mini_path, teardown, server, dj_conn): dj_logger.info("Inserting test data.") if len(Nwbfile()) > 0: - Nwbfile().delete(safemode=False) + Nwbfile().cautious_delete(force_permission=True, safemode=False) if server.connected: insert_sessions(mini_path.name) diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py index de399b0e6..d6b845ec9 100644 --- a/tests/lfp/conftest.py +++ b/tests/lfp/conftest.py @@ -1,5 +1,3 @@ -import copy - import numpy as np import pytest @@ -13,20 +11,33 @@ def lfp(common): @pytest.fixture(scope="session") def lfp_band(lfp): - return lfp.analysis.v1 + from spyglass.lfp.analysis.v1 import lfp_band + + return lfp_band + + +@pytest.fixture(scope="session") +def firfilters_table(common): + return common.FirFilterParameters() + + +@pytest.fixture(scope="session") +def electrodegroup_table(lfp): + return lfp.v1.LFPElectrodeGroup() @pytest.fixture(scope="session") def lfp_constants(common, mini_copy_name): n_delay = 9 lfp_electrode_group_name = "test" - orig_interval_list_name, orig_valid_times = ( - common.IntervalList & "interval_list_name LIKE '01_%'" - ).fetch("interval_list_name", "valid_times")[0] - new_interval_list_name = orig_interval_list_name + f"_first{n_delay}" - new_interval_list_key = { + orig_list_name = "01_s1" + orig_valid_times = ( + common.IntervalList & f"interval_list_name = '{orig_list_name}'" + ).fetch1("valid_times") + new_list_name = orig_list_name + f"_first{n_delay}" + new_list_key = { "nwb_file_name": mini_copy_name, - "interval_list_name": new_interval_list_name, + "interval_list_name": new_list_name, "valid_times": np.asarray( [[orig_valid_times[0, 0], orig_valid_times[0, 0] + n_delay]] ), @@ -40,10 +51,10 @@ def lfp_constants(common, 
mini_copy_name): "lfp_electrode_group_name": lfp_electrode_group_name, }, n_delay=n_delay, - orig_interval_list_name="01_s1", + orig_interval_list_name=orig_list_name, orig_valid_times=orig_valid_times, - interval_list_name=new_interval_list_name, - interval_key=new_interval_list_key, + interval_list_name=new_list_name, + interval_key=new_list_key, filter1_name="LFP 0-400 Hz", filter_sampling_rate=30_000, filter2_name="Theta 5-11 Hz", @@ -53,13 +64,21 @@ def lfp_constants(common, mini_copy_name): @pytest.fixture(scope="session") -def add_electrode_group(common, mini_copy_name, lfp_constants): - common.FirFilterParameters().create_standard_filters() - lfp.lfp_electrode.LFPElectrodeGroup.create_lfp_electrode_group( +def add_electrode_group( + firfilters_table, + electrodegroup_table, + mini_copy_name, + lfp_constants, + teardown, +): + firfilters_table.create_standard_filters() + electrodegroup_table.create_lfp_electrode_group( nwb_file_name=mini_copy_name, group_name=lfp_constants.get("lfp_electrode_group_name"), electrode_list=lfp_constants.get("lfp_electrode_ids"), ) + if teardown: + electrodegroup_table.delete(safemode=False) @pytest.fixture(scope="session") @@ -71,7 +90,9 @@ def add_interval(common, lfp_constants): @pytest.fixture(scope="session") -def add_selection(lfp, common, add_interval, lfp_constants): +def add_selection( + lfp, common, add_electrode_group, add_interval, lfp_constants, teardown +): lfp_s_key = { **lfp_constants.get("lfp_eg_key"), "target_interval_list_name": add_interval, @@ -80,11 +101,19 @@ def add_selection(lfp, common, add_interval, lfp_constants): } lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True) yield lfp_s_key + if teardown: + lfp.v1.LFPSelection().delete(safemode=False) @pytest.fixture(scope="session") -def lfp_s_key(add_selection): - yield add_selection +def lfp_s_key(lfp_constants, mini_copy_name): + yield { + "nwb_file_name": mini_copy_name, + "lfp_electrode_group_name": lfp_constants.get( + "lfp_electrode_group_name" + ), + "target_interval_list_name": lfp_constants.get("interval_list_name"), + } @pytest.fixture(scope="session") @@ -106,15 +135,20 @@ def lfp_band_sampling_rate(lfp, lfp_merge_key): @pytest.fixture(scope="session") -def add_band_filter(common, lfp_constants, lfp_band_sampling_rate): +def add_band_filter(common, lfp_constants, lfp_band_sampling_rate, teardown): + filter_name = lfp_constants.get("filter2_name") common.FirFilterParameters().add_filter( - lfp_constants.get("filter2_name"), + filter_name, lfp_band_sampling_rate, "bandpass", [4, 5, 11, 12], "theta filter for 1 Khz data", ) yield lfp_constants.get("filter2_name") + if teardown: + (common.FirFilterParameters() & {"filter_name": filter_name}).delete( + safemode=False + ) @pytest.fixture(scope="session") @@ -125,6 +159,7 @@ def add_band_selection( add_interval, lfp_constants, add_band_filter, + teardown, ): lfp_band.LFPBandSelection().set_lfp_band_electrodes( nwb_file_name=mini_copy_name, @@ -138,6 +173,8 @@ def add_band_selection( yield (lfp_band.LFPBandSelection().fetch1("KEY") & lfp_merge_key).fetch1( "KEY" ) + if teardown: + (lfp_band.LFPBandSelection() & lfp_merge_key).delete(safemode=False) @pytest.fixture(scope="session") From 4fad2e19419d0784b269cf92e8accdb4ad6f5d50 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 16 Jan 2024 12:58:01 -0600 Subject: [PATCH 08/16] Add lfp pipeline tests --- pyproject.toml | 6 ++--- src/spyglass/utils/dj_mixin.py | 10 +++++-- tests/common/test_behav.py | 6 +++-- tests/conftest.py | 14 ++++++---- tests/lfp/conftest.py | 48 
++++++++++++++++++++++++++-------- tests/lfp/test_pipeline.py | 39 +++++++++++++-------------- 6 files changed, 79 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3a482178b..8347836ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,15 +131,15 @@ log_level = "INFO" [tool.coverage.run] source = ["*/src/spyglass/*"] -omit = [ +omit = [ # which submodules have no tests "*/__init__.py", "*/_version.py", "*/cli/*", # "*/common/*", - # "*/data_import/*", + "*/data_import/*", "*/decoding/*", "*/figurl_views/*", - "*/lfp/*", + # "*/lfp/*", "*/linearization/*", "*/lock/*", "*/position/*", diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 8a53743de..3ee0f6292 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -226,7 +226,13 @@ def _check_delete_permission(self) -> None: user_name = LabMember().get_djuser_name(dj_user) for experimenter in set(experimenters): - if user_name not in LabTeam().get_team_members(experimenter): + # Check once with cache, if fails, reload and check again + # On eval as set, reload will only be called once + if user_name not in LabTeam().get_team_members( + experimenter + ) and user_name not in LabTeam().get_team_members( + experimenter, reload=True + ): sess_w_exp = sess_summary & {self._member_pk: experimenter} raise PermissionError( f"User '{user_name}' is not on a team with '{experimenter}'" @@ -259,7 +265,7 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): merge_deletes = self._merge_del_func( self, - restriction=self.restriction, + restriction=self.restriction if self.restriction else None, dry_run=True, disable_warning=True, ) diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index 65e4addc3..4a3689f41 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -62,8 +62,10 @@ def test_get_pos_interval_name(pos_src, pos_interval_01): assert pos_interval_01 == names, "get_pos_interval_name failed" -def test_convert_epoch(common, pos_interval_01): - this_key = (common.IntervalList & {"interval_list_name": "01_s1"}).fetch1() +def test_convert_epoch(common, mini_dict, pos_interval_01): + this_key = ( + common.IntervalList & mini_dict & {"interval_list_name": "01_s1"} + ).fetch1() ret = common.common_behav.convert_epoch_interval_name_to_position_interval_name( this_key ) diff --git a/tests/conftest.py b/tests/conftest.py index 0a328c767..5ce8e9c77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -183,17 +183,20 @@ def mini_closed(mini_path): def mini_insert(mini_path, teardown, server, dj_conn): from spyglass.common import Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 + from spyglass.utils.dj_merge_tables import delete_downstream_merge # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 dj_logger.info("Inserting test data.") - if len(Nwbfile()) > 0: - Nwbfile().cautious_delete(force_permission=True, safemode=False) - - if server.connected: + if not server.connected: + dj_logger.error("No server connection.") + elif len(Nwbfile()) == 0 and server.connected: insert_sessions(mini_path.name) else: - dj_logger.error("No server connection.") + dj_logger.warning( + "Nwbfile table not empty. Skipping insert, use existing data." 
+ ) + if len(Session()) == 0: dj_logger.error("No sessions inserted.") @@ -201,6 +204,7 @@ def mini_insert(mini_path, teardown, server, dj_conn): close_nwb_files() if teardown: + delete_downstream_merge(table=Nwbfile()) Nwbfile().delete(safemode=False) diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py index d6b845ec9..0a546ca62 100644 --- a/tests/lfp/conftest.py +++ b/tests/lfp/conftest.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from pynwb import NWBHDF5IO @pytest.fixture(scope="session") @@ -27,12 +28,14 @@ def electrodegroup_table(lfp): @pytest.fixture(scope="session") -def lfp_constants(common, mini_copy_name): +def lfp_constants(common, mini_copy_name, mini_dict): n_delay = 9 lfp_electrode_group_name = "test" orig_list_name = "01_s1" orig_valid_times = ( - common.IntervalList & f"interval_list_name = '{orig_list_name}'" + common.IntervalList + & mini_dict + & f"interval_list_name = '{orig_list_name}'" ).fetch1("valid_times") new_list_name = orig_list_name + f"_first{n_delay}" new_list_key = { @@ -117,7 +120,7 @@ def lfp_s_key(lfp_constants, mini_copy_name): @pytest.fixture(scope="session") -def populate_lfp(lfp, add_selection): +def populate_lfp(lfp, add_selection, lfp_s_key): lfp.v1.LFPV1().populate(add_selection) yield {"merge_id": (lfp.LFPOutput.LFPV1() & lfp_s_key).fetch1("merge_id")} @@ -127,6 +130,18 @@ def lfp_merge_key(populate_lfp): yield populate_lfp +@pytest.fixture(scope="module") +def lfp_analysis_raw(common, lfp, populate_lfp, mini_dict): + abs_path = (common.AnalysisNwbfile * lfp.v1.LFPV1 & mini_dict).fetch( + "analysis_file_abs_path" + )[0] + assert abs_path is not None, "No NWBFile found." + with NWBHDF5IO(path=str(abs_path), mode="r", load_namespaces=True) as io: + nwbfile = io.read() + assert nwbfile is not None, "NWBFile empty." + yield nwbfile + + @pytest.fixture(scope="session") def lfp_band_sampling_rate(lfp, lfp_merge_key): yield lfp.LFPOutput.merge_get_parent(lfp_merge_key).fetch1( @@ -155,6 +170,7 @@ def add_band_filter(common, lfp_constants, lfp_band_sampling_rate, teardown): def add_band_selection( lfp_band, mini_copy_name, + mini_dict, lfp_merge_key, add_interval, lfp_constants, @@ -170,9 +186,7 @@ def add_band_selection( reference_electrode_list=[-1], lfp_band_sampling_rate=lfp_constants.get("lfp_band_sampling_rate"), ) - yield (lfp_band.LFPBandSelection().fetch1("KEY") & lfp_merge_key).fetch1( - "KEY" - ) + yield (lfp_band.LFPBandSelection & mini_dict).fetch1("KEY") if teardown: (lfp_band.LFPBandSelection() & lfp_merge_key).delete(safemode=False) @@ -188,8 +202,20 @@ def populate_lfp_band(lfp_band, add_band_selection): yield -@pytest.fixture(scope="session") -def mini_eseries(common, mini_copy_name): - yield (common.Raw() & {"nwb_file_name": mini_copy_name}).fetch_nwb()[0][ - "raw" - ] +# @pytest.fixture(scope="session") +# def mini_eseries(common, mini_copy_name): +# yield (common.Raw() & {"nwb_file_name": mini_copy_name}).fetch_nwb()[0][ +# "raw" +# ] + + +@pytest.fixture(scope="module") +def lfp_band_analysis_raw(common, lfp_band, populate_lfp_band, mini_dict): + abs_path = (common.AnalysisNwbfile * lfp_band.LFPBandV1 & mini_dict).fetch( + "analysis_file_abs_path" + )[0] + assert abs_path is not None, "No NWBFile found." + with NWBHDF5IO(path=str(abs_path), mode="r", load_namespaces=True) as io: + nwbfile = io.read() + assert nwbfile is not None, "NWBFile empty." 
+ yield nwbfile diff --git a/tests/lfp/test_pipeline.py b/tests/lfp/test_pipeline.py index a6567f548..86599190d 100644 --- a/tests/lfp/test_pipeline.py +++ b/tests/lfp/test_pipeline.py @@ -1,28 +1,25 @@ -import numpy as np +from pandas import DataFrame, Index -def test_lfp_eseries(common, lfp, mini_eseries, lfp_constants, lfp_merge_key): - lfp_elect_ids = lfp_constants.get("lfp_band_electrode_ids") - - lfp_elect_indices = common.get_electrode_indices( - mini_eseries, lfp_elect_ids +def test_lfp_dataframe(common, lfp, lfp_analysis_raw, lfp_merge_key): + lfp_raw = lfp_analysis_raw.scratch["filtered data"] + df_raw = DataFrame( + lfp_raw.data, index=Index(lfp_raw.timestamps, name="time") ) - lfp_timestamps = np.asarray(mini_eseries.timestamps) - lfp_eseries = lfp.LFPOutput.fetch_nwb(lfp_merge_key)[0]["lfp"] - assert False + df_fetch = (lfp.LFPOutput & lfp_merge_key).fetch1_dataframe() + + assert df_raw.equals(df_fetch), "LFP dataframe does not match." -def test_lfp_band_eseries(ccf, lfp_band, lfp_band_key, lfp_constants): - lfp_band_elect_ids = lfp_constants.get("lfp_band_electrode_ids") - lfp_elect_indices = common.get_electrode_indices( - lfp_eseries, lfp_band_electrode_ids +def test_lfp_band_dataframe(lfp_band_analysis_raw, lfp_band, lfp_band_key): + lfp_band_raw = ( + lfp_band_analysis_raw.processing["ecephys"] + .fields["data_interfaces"]["LFP"] + .electrical_series["filtered data"] ) - lfp_timestamps = np.asarray(lfp_eseries.timestamps) - lfp_band_eseries = (lfp_band.LFPBandV1 & lfp_band_key).fetch_nwb()[0][ - "lfp_band" - ] - lfp_band_elect_indices = common.get_electrode_indices( - lfp_band_eseries, lfp_band_electrode_ids + df_raw = DataFrame( + lfp_band_raw.data, index=Index(lfp_band_raw.timestamps, name="time") ) - lfp_band_timestamps = np.asarray(lfp_band_eseries.timestamps) - assert False + df_fetch = (lfp_band.LFPBandV1 & lfp_band_key).fetch1_dataframe() + + assert df_raw.equals(df_fetch), "LFPBand dataframe does not match." From 6fdb0ccabd59ae2e0248d1b27c676b79d9d38659 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 16 Jan 2024 15:26:44 -0600 Subject: [PATCH 09/16] Run pre-commit checks --- src/spyglass/decoding/decoding_merge.py | 4 ++-- tests/common/test_interval_helpers.py | 2 +- tests/conftest.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index c49971c78..1752b1165 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -21,14 +21,14 @@ class DecodingOutput(_Merge, SpyglassMixin): source: varchar(32) """ - class ClusterlessDecodingV1(SpyglassMixin, dj.Part): + class ClusterlessDecodingV1(SpyglassMixin, dj.Part): # noqa: F811 definition = """ -> master --- -> ClusterlessDecodingV1 """ - class SortedSpikesDecodingV1(SpyglassMixin, dj.Part): + class SortedSpikesDecodingV1(SpyglassMixin, dj.Part): # noqa: F811 definition = """ -> master --- diff --git a/tests/common/test_interval_helpers.py b/tests/common/test_interval_helpers.py index de7aad6e7..d4e7eb1ac 100644 --- a/tests/common/test_interval_helpers.py +++ b/tests/common/test_interval_helpers.py @@ -83,7 +83,7 @@ def test_set_difference(set_difference, one, two, expected_result): ], ) def test_intervals_by_length(common, expected_result, min_len, max_len): - # intput is the same across all tests. Could be parametrized as above + # input is the same across all tests. 
Could be parametrized as above inds = common.common_interval.intervals_by_length( interval_list=np.array([[0, 0], [0, 1], [0, 1e11]]), min_length=min_len, diff --git a/tests/conftest.py b/tests/conftest.py index 5ce8e9c77..e6c331637 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -183,7 +183,9 @@ def mini_closed(mini_path): def mini_insert(mini_path, teardown, server, dj_conn): from spyglass.common import Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 - from spyglass.utils.dj_merge_tables import delete_downstream_merge # noqa: E402 + from spyglass.utils.dj_merge_tables import ( + delete_downstream_merge, + ) # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 dj_logger.info("Inserting test data.") From 5d0b15fa6bc756b5e859502823d5e844af9ee9f5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 16 Jan 2024 15:52:52 -0600 Subject: [PATCH 10/16] Fix bug --- pyproject.toml | 2 +- src/spyglass/common/common_position.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8347836ab..096d3a185 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ keywords = [ dependencies = [ "pydotplus", "dask", - "position_tools", + "position_tools>=0.1", "track_linearization>=2.3", "non_local_detector", "ripple_detection", diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 86ddd1403..732c9779e 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -32,8 +32,8 @@ try: from position_tools import get_centroid except ImportError: - logger.warnint("Please update position_tools to >= 0.1.0") - from position_tools import get_centroid + logger.warning("Please update position_tools to >= 0.1.0") + from position_tools import get_centriod as get_centroid schema = dj.schema("common_position") From ef3d85ef51854322e7011261b54b9fd921ec0f2a Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 17 Jan 2024 10:41:42 -0600 Subject: [PATCH 11/16] Unpin position_tools for CI --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 096d3a185..8347836ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ keywords = [ dependencies = [ "pydotplus", "dask", - "position_tools>=0.1", + "position_tools", "track_linearization>=2.3", "non_local_detector", "ripple_detection", From 770e53ea8fdaada899fb6caf9a1bc0350b3aa30d Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 17 Jan 2024 13:05:27 -0600 Subject: [PATCH 12/16] Change download data dir --- .github/workflows/test-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml index d30475527..ede4f2c68 100644 --- a/.github/workflows/test-conda.yml +++ b/.github/workflows/test-conda.yml @@ -45,7 +45,7 @@ jobs: UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb run: | - mkdir -p ./tests/test_data; mkdir -p ./tests/test_data/raw + mkdir -p ./tests/_data; mkdir -p ./tests/_data/raw wget --recursive --no-verbose --no-host-directories --no-directories \ --user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \ -P ./tests/test_data/raw $WEBSITE From e8c19114b71d6b62165af7bcdb94c0870115ff3e Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 17 Jan 2024 13:45:43 -0600 Subject: [PATCH 13/16] Change download data dir 2 --- 
.github/workflows/test-conda.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-conda.yml b/.github/workflows/test-conda.yml index ede4f2c68..594a7b2b8 100644 --- a/.github/workflows/test-conda.yml +++ b/.github/workflows/test-conda.yml @@ -44,11 +44,12 @@ jobs: UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }} UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }} WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb + RAW_DIR: /home/runner/work/spyglass/spyglass/tests/_data/raw/ run: | - mkdir -p ./tests/_data; mkdir -p ./tests/_data/raw + mkdir -p $RAW_DIR wget --recursive --no-verbose --no-host-directories --no-directories \ --user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \ - -P ./tests/test_data/raw $WEBSITE + -P $RAW_DIR $WEBSITE - name: Run tests run: | pytest -rP # env vars are set within certain tests From dce153065a9794d0a09d6c60ee09d899d96f1c10 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 17 Jan 2024 16:40:57 -0600 Subject: [PATCH 14/16] Fix teardown. Coverage 67% --- pyproject.toml | 4 +- src/spyglass/utils/dj_merge_tables.py | 2 +- tests/common/test_behav.py | 1 - tests/common/test_insert.py | 3 +- tests/common/test_lab.py | 9 +-- tests/common/test_nwbfile.py | 1 - tests/conftest.py | 92 +++++++++++++++++++++------ tests/container.py | 9 ++- tests/lfp/conftest.py | 24 +++---- 9 files changed, 93 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8347836ab..521224737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,8 +118,8 @@ minversion = "7.0" addopts = [ "-sv", "-p no:warnings", - "--no-teardown", - "--quiet-spy", + # "--no-teardown", # don't teardown the database after tests + # "--quiet-spy", # don't show logging from spyglass "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger "--cov=spyglass", diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index eddd77652..5c900b66c 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -758,7 +758,7 @@ def delete_downstream_merge( def _warn_on_restriction(table: dj.Table, restriction: str = None): """Warn if restriction on table object differs from input restriction""" - if restriction is None and table().restriction: + if restriction is None and table.restriction: logger.warn( f"Warning: ignoring table restriction: {table().restriction}.\n\t" + "Please pass restrictions as an arg" diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index 4a3689f41..c21ed96f6 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -1,5 +1,4 @@ import pytest -from numpy import array_equal from pandas import DataFrame diff --git a/tests/common/test_insert.py b/tests/common/test_insert.py index 9d6f87ef3..6d2fd18b3 100644 --- a/tests/common/test_insert.py +++ b/tests/common/test_insert.py @@ -109,8 +109,9 @@ def test_insert_dio(mini_insert, mini_behavior, mini_restr, common): assert len(events_data) == len(events_raw), "Number of events not match" - event = "Poke1" + event = [p for p in events_raw.keys() if "Poke" in p][0] event_raw = events_raw.get(event) + # event_data = (common.DIOEvents & {"dio_event_name": event}).fetch(as_dict=True)[0] event_data = (common.DIOEvents & {"dio_event_name": event}).fetch1() assert ( diff --git a/tests/common/test_lab.py b/tests/common/test_lab.py index 7c74ecd1c..83ab84c10 100644 --- a/tests/common/test_lab.py +++ b/tests/common/test_lab.py @@ -8,7 +8,7 @@ def 
common_lab(common): @pytest.fixture -def add_admin(common_lab, teardown): +def add_admin(common_lab): common_lab.LabMember.insert1( dict( lab_member_name="This Admin", @@ -27,12 +27,10 @@ def add_admin(common_lab, teardown): skip_duplicates=True, ) yield - if teardown: - common_lab.LabMember.delete(safe_mode=False) @pytest.fixture -def add_member_team(common_lab, add_admin, teardown): +def add_member_team(common_lab, add_admin): common_lab.LabMember.insert( [ dict( @@ -71,9 +69,6 @@ def add_member_team(common_lab, add_admin, teardown): team_description="This Team Description", ) yield - if teardown: - common_lab.LabMember.delete(safe_mode=False) - common_lab.LabTeam.delete(safe_mode=False) def test_labmember_insert_file_str(mini_insert, common_lab, mini_copy_name): diff --git a/tests/common/test_nwbfile.py b/tests/common/test_nwbfile.py index 518955274..a8671b7ce 100644 --- a/tests/common/test_nwbfile.py +++ b/tests/common/test_nwbfile.py @@ -17,7 +17,6 @@ def lockfile(base_dir, teardown): yield lockfile if teardown: os.remove(lockfile) - lockfile.unlink() def test_get_file_name_error(common_nwbfile): diff --git a/tests/conftest.py b/tests/conftest.py index e6c331637..3c2bc866b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from subprocess import Popen from time import sleep as tsleep import datajoint as dj @@ -14,8 +15,8 @@ # ---------------------- CONSTANTS --------------------- -# globals in pytest_configure: BASE_DIR, SERVER, TEARDOWN, VERBOSE -# download managed by gh-action test-conda, so no need to download here +# globals in pytest_configure: +# BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") @@ -65,13 +66,15 @@ def pytest_addoption(parser): def pytest_configure(config): - global BASE_DIR, SERVER, TEARDOWN, VERBOSE + global BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD + TEST_FILE = "minirec20230622.nwb" TEARDOWN = not config.option.no_teardown VERBOSE = not config.option.quiet_spy BASE_DIR = Path(config.option.base_dir).absolute() BASE_DIR.mkdir(parents=True, exist_ok=True) + RAW_DIR = BASE_DIR / "raw" os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR) SERVER = DockerMySQLManager( @@ -80,6 +83,52 @@ def pytest_configure(config): null_server=config.option.no_server, verbose=VERBOSE, ) + DOWNLOAD = download_data(verbose=VERBOSE) + + +def data_is_downloaded(): + """Check if data is downloaded.""" + return os.path.exists(RAW_DIR / TEST_FILE) + + +def download_data(verbose=False): + """Download data from BOX using environment variable credentials. + + Note: In gh-actions, this is handled by the test-conda workflow. + """ + if data_is_downloaded(): + return None + UCSF_BOX_USER = os.environ.get("UCSF_BOX_USER") + UCSF_BOX_TOKEN = os.environ.get("UCSF_BOX_TOKEN") + if not all([UCSF_BOX_USER, UCSF_BOX_TOKEN]): + raise ValueError( + "Missing data, no credentials: UCSF_BOX_USER or UCSF_BOX_TOKEN." 
diff --git a/tests/conftest.py b/tests/conftest.py
index e6c331637..3c2bc866b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
+from subprocess import Popen
 from time import sleep as tsleep

 import datajoint as dj
@@ -14,8 +15,8 @@

 # ---------------------- CONSTANTS ---------------------

-# globals in pytest_configure: BASE_DIR, SERVER, TEARDOWN, VERBOSE
-# download managed by gh-action test-conda, so no need to download here
+# globals in pytest_configure:
+# BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD

 warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")

@@ -65,13 +66,15 @@ def pytest_addoption(parser):


 def pytest_configure(config):
-    global BASE_DIR, SERVER, TEARDOWN, VERBOSE
+    global BASE_DIR, RAW_DIR, SERVER, TEARDOWN, VERBOSE, TEST_FILE, DOWNLOAD

+    TEST_FILE = "minirec20230622.nwb"
     TEARDOWN = not config.option.no_teardown
     VERBOSE = not config.option.quiet_spy

     BASE_DIR = Path(config.option.base_dir).absolute()
     BASE_DIR.mkdir(parents=True, exist_ok=True)
+    RAW_DIR = BASE_DIR / "raw"
     os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR)

     SERVER = DockerMySQLManager(
@@ -80,6 +83,52 @@ def pytest_configure(config):
         null_server=config.option.no_server,
         verbose=VERBOSE,
     )
+    DOWNLOAD = download_data(verbose=VERBOSE)
+
+
+def data_is_downloaded():
+    """Check if data is downloaded."""
+    return os.path.exists(RAW_DIR / TEST_FILE)
+
+
+def download_data(verbose=False):
+    """Download data from BOX using environment variable credentials.
+
+    Note: In gh-actions, this is handled by the test-conda workflow.
+    """
+    if data_is_downloaded():
+        return None
+    UCSF_BOX_USER = os.environ.get("UCSF_BOX_USER")
+    UCSF_BOX_TOKEN = os.environ.get("UCSF_BOX_TOKEN")
+    if not all([UCSF_BOX_USER, UCSF_BOX_TOKEN]):
+        raise ValueError(
+            "Missing data, no credentials: UCSF_BOX_USER or UCSF_BOX_TOKEN."
+        )
+    data_url = f"ftps://ftp.box.com/trodes_to_nwb_test_data/{TEST_FILE}"
+
+    cmd = [
+        "wget",
+        "--recursive",
+        "--no-host-directories",
+        "--no-directories",
+        "--user",
+        UCSF_BOX_USER,
+        "--password",
+        UCSF_BOX_TOKEN,
+        "-P",
+        RAW_DIR,
+        data_url,
+    ]
+    if not verbose:
+        cmd.insert(cmd.index("--recursive") + 1, "--no-verbose")
+    cmd_kwargs = dict(stdout=sys.stdout, stderr=sys.stderr) if verbose else {}
+
+    return Popen(cmd, **cmd_kwargs)
+
+
+def pytest_unconfigure(config):
+    if TEARDOWN:
+        SERVER.stop()


 # ------------------- FIXTURES -------------------

@@ -138,13 +187,23 @@ def raw_dir(base_dir):

 @pytest.fixture(scope="session")
 def mini_path(raw_dir):
-    path = raw_dir / "minirec20230622.nwb"
+    path = raw_dir / TEST_FILE

-    timeout, wait = 60, 5  # download managed by gh-action test-conda
-    for _ in range(timeout // wait):  # wait for download to finish
+    # wait for wget download to finish
+    if DOWNLOAD is not None:
+        DOWNLOAD.wait()
+
+    # wait for gh-actions download to finish
+    timeout, wait, found = 60, 5, False
+    for _ in range(timeout // wait):
         if path.exists():
+            found = True
             break
         tsleep(wait)
+
+    if not found:
+        raise ConnectionError("Download failed.")
+
     yield path

@@ -183,31 +242,26 @@ def mini_closed(mini_path):
 def mini_insert(mini_path, teardown, server, dj_conn):
     from spyglass.common import Nwbfile, Session  # noqa: E402
     from spyglass.data_import import insert_sessions  # noqa: E402
-    from spyglass.utils.dj_merge_tables import (
-        delete_downstream_merge,
-    )  # noqa: E402
     from spyglass.utils.nwb_helper_fn import close_nwb_files  # noqa: E402

     dj_logger.info("Inserting test data.")

     if not server.connected:
-        dj_logger.error("No server connection.")
-    elif len(Nwbfile()) == 0 and server.connected:
-        insert_sessions(mini_path.name)
+        raise ConnectionError("No server connection.")
+
+    if len(Nwbfile()) != 0:
+        dj_logger.warning("Skipping insert, use existing data.")
     else:
-        dj_logger.warning(
-            "Nwbfile table not empty. Skipping insert, use existing data."
-        )
+        insert_sessions(mini_path.name)

     if len(Session()) == 0:
-        dj_logger.error("No sessions inserted.")
+        raise ValueError("No sessions inserted.")

     yield

     close_nwb_files()
-    if teardown:
-        delete_downstream_merge(table=Nwbfile())
-        Nwbfile().delete(safemode=False)
+    # Note: no need to run deletes in teardown; tearing down the Docker
+    # container will remove the tables


 @pytest.fixture(scope="session")
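The `conftest.py` rewrite above launches `wget` as a background process at configure time and only blocks in the `mini_path` fixture, so collection and tests that never touch the file are not held up by the transfer. The start-early/wait-late pattern in isolation — `slow_fetch.sh` is a hypothetical placeholder for any long-running command:

```python
from subprocess import Popen

# Popen returns immediately; the download proceeds in the background.
proc = Popen(["sh", "slow_fetch.sh"])  # hypothetical slow command

# ... do unrelated setup work here ...

proc.wait()  # block only at the point the result is actually needed
if proc.returncode != 0:
    raise ConnectionError("Background download failed.")
```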
diff --git a/tests/container.py b/tests/container.py
index d178e1fce..df820f1d0 100644
--- a/tests/container.py
+++ b/tests/container.py
@@ -131,10 +131,10 @@ def wait(self, timeout=120, wait=5) -> None:
         if not self.container_status or self.container_status == "exited":
             self.start()
-        for _ in range(timeout // wait):
+        for i in range(timeout // wait):
             if self.container.health == "healthy":
                 break
-            self.logger.info(f"Container {self.container_name} starting...")
+            self.logger.info(f"Container {self.container_name} starting... {i}")
             time.sleep(wait)
         self.logger.info(
             f"Container {self.container_name}, {self.container.health}."
         )
@@ -206,12 +206,11 @@ def stop(self, remove=True) -> None:
         if self.null_server:
             return None
         if not self.container_status or self.container_status == "exited":
-            self.logger.info(
-                f"Container {self.container_name} already stopped."
-            )
             return
+
         self.container.stop()
         self.logger.info(f"Container {self.container_name} stopped.")
+
         if remove:
             self.container.remove()
             self.logger.info(f"Container {self.container_name} removed.")
diff --git a/tests/lfp/conftest.py b/tests/lfp/conftest.py
index 0a546ca62..2eb511265 100644
--- a/tests/lfp/conftest.py
+++ b/tests/lfp/conftest.py
@@ -72,16 +72,18 @@ def add_electrode_group(
     electrodegroup_table,
     mini_copy_name,
     lfp_constants,
-    teardown,
 ):
     firfilters_table.create_standard_filters()
+    group_name = lfp_constants.get("lfp_electrode_group_name")
     electrodegroup_table.create_lfp_electrode_group(
         nwb_file_name=mini_copy_name,
-        group_name=lfp_constants.get("lfp_electrode_group_name"),
+        group_name=group_name,
         electrode_list=lfp_constants.get("lfp_electrode_ids"),
     )
-    if teardown:
-        electrodegroup_table.delete(safemode=False)
+    assert len(
+        electrodegroup_table & {"lfp_electrode_group_name": group_name}
+    ), "Failed to add LFPElectrodeGroup."
+    yield


 @pytest.fixture(scope="session")
 def add_interval(common, lfp_constants):
@@ -94,7 +96,7 @@ def add_interval(common, lfp_constants):

 @pytest.fixture(scope="session")
 def add_selection(
-    lfp, common, add_electrode_group, add_interval, lfp_constants, teardown
+    lfp, common, add_electrode_group, add_interval, lfp_constants
 ):
     lfp_s_key = {
         **lfp_constants.get("lfp_eg_key"),
@@ -104,8 +106,6 @@ def add_selection(
     }
     lfp.v1.LFPSelection.insert1(lfp_s_key, skip_duplicates=True)
     yield lfp_s_key
-    if teardown:
-        lfp.v1.LFPSelection().delete(safemode=False)


 @pytest.fixture(scope="session")
@@ -150,7 +150,7 @@ def lfp_band_sampling_rate(lfp, lfp_merge_key):


 @pytest.fixture(scope="session")
-def add_band_filter(common, lfp_constants, lfp_band_sampling_rate, teardown):
+def add_band_filter(common, lfp_constants, lfp_band_sampling_rate):
     filter_name = lfp_constants.get("filter2_name")
     common.FirFilterParameters().add_filter(
         filter_name,
@@ -160,10 +160,6 @@ def add_band_filter(common, lfp_constants, lfp_band_sampling_rate):
         "theta filter for 1 Khz data",
     )
     yield lfp_constants.get("filter2_name")
-    if teardown:
-        (common.FirFilterParameters() & {"filter_name": filter_name}).delete(
-            safemode=False
-        )


 @pytest.fixture(scope="session")
@@ -175,7 +171,7 @@ def add_band_selection(
     add_interval,
     lfp_constants,
     add_band_filter,
-    teardown,
+    add_electrode_group,
 ):
     lfp_band.LFPBandSelection().set_lfp_band_electrodes(
         nwb_file_name=mini_copy_name,
@@ -187,8 +183,6 @@ def add_band_selection(
         lfp_band_sampling_rate=lfp_constants.get("lfp_band_sampling_rate"),
     )
     yield (lfp_band.LFPBandSelection & mini_dict).fetch1("KEY")
-    if teardown:
-        (lfp_band.LFPBandSelection() & lfp_merge_key).delete(safemode=False)


 @pytest.fixture(scope="session")
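One detail of the `lfp/conftest.py` hunks worth flagging: `add_band_selection` swaps its now-unused `teardown` argument for `add_electrode_group`. In pytest, naming another fixture as a parameter both injects its value and guarantees it runs first, which stands in for the deleted cleanup-time bookkeeping. A generic sketch (fixture names here are illustrative, not the Spyglass ones):

```python
import pytest


@pytest.fixture(scope="session")
def electrode_group():
    """Upstream fixture: runs before any fixture or test that names it."""
    return "group-0"


@pytest.fixture(scope="session")
def band_selection(electrode_group):
    # Listing electrode_group as a parameter is pytest's dependency
    # declaration: setup order is enforced without manual sequencing.
    return {"group": electrode_group, "band": "theta"}
```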
From 344ca0b69cdbe6b88a0e088c148c31b62600d657 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Thu, 18 Jan 2024 10:54:26 -0600
Subject: [PATCH 15/16] Update changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 895702b43..302c116d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@
 - Add `deprecation_factory` to facilitate table migration. #717
 - Add Spyglass logger. #730
 - IntervalList: Add secondary key `pipeline` #742
+- Increase pytest coverage for `common`, `lfp`, and `utils`. #743

 ### Pipelines

@@ -31,7 +32,6 @@

 - Allow multiple spike waveform features for clusterless decoding #731
 - Reorder notebooks #731
-
 ## [0.4.3] (November 7, 2023)

 - Migrate `config` helper scripts to Spyglass codebase. #662

From 587e231b3fed1527ee1cd35c315b7c339e1c8df8 Mon Sep 17 00:00:00 2001
From: Chris Brozdowski
Date: Fri, 19 Jan 2024 13:12:01 -0600
Subject: [PATCH 16/16] logger.warn -> logger.warning

---
 src/spyglass/data_import/insert_sessions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spyglass/data_import/insert_sessions.py b/src/spyglass/data_import/insert_sessions.py
index f31b0c09e..329a7be42 100644
--- a/src/spyglass/data_import/insert_sessions.py
+++ b/src/spyglass/data_import/insert_sessions.py
@@ -101,7 +101,7 @@ def copy_nwb_link_raw_ephys(nwb_file_name, out_nwb_file_name):
     if os.path.exists(out_nwb_file_abs_path):
         if debug_mode:
             return out_nwb_file_abs_path
-        logger.warn(
+        logger.warning(
             f"Output file {out_nwb_file_abs_path} exists and will be "
             + "overwritten."
         )
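The final patch is a one-word deprecation fix: `Logger.warn` has long been a deprecated alias of `Logger.warning` in the Python standard library, so the rename silences a `DeprecationWarning` without changing behavior. In isolation:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("spyglass")

logger.warning("file exists and will be overwritten")  # preferred spelling
# logger.warn(...) logs the same message but emits a DeprecationWarning
```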