Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytest revamp #743

Merged
merged 18 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions .github/workflows/test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ jobs:
env:
OS: ${{ matrix.os }}
PYTHON: '3.8'
# SPYGLASS_BASE_DIR: ./data
# KACHERY_STORAGE_DIR: ./data/kachery-storage
# DJ_SUPPORT_FILEPATH_MANAGEMENT: True
# services:
# datajoint_test_server:
# image: datajoint/mysql
# ports:
# - 3306:3306
# options: >-
# -e MYSQL_ROOT_PASSWORD=tutorial
steps:
- name: Cancel Workflow Action
uses: styfle/[email protected]
Expand All @@ -49,6 +39,17 @@ jobs:
- name: Install spyglass
run: |
pip install -e .[test]
- name: Download data
env:
UCSF_BOX_TOKEN: ${{ secrets.UCSF_BOX_TOKEN }}
UCSF_BOX_USER: ${{ secrets.UCSF_BOX_USER }}
WEBSITE: ftps://ftp.box.com/trodes_to_nwb_test_data/minirec20230622.nwb
RAW_DIR: /home/runner/work/spyglass/spyglass/tests/_data/raw/
run: |
mkdir -p $RAW_DIR
wget --recursive --no-verbose --no-host-directories --no-directories \
--user $UCSF_BOX_USER --password $UCSF_BOX_TOKEN \
-P $RAW_DIR $WEBSITE
- name: Run tests
run: |
pytest -rP # env vars are set within certain tests
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- Add `deprecation_factory` to facilitate table migration. #717
- Add Spyglass logger. #730
- IntervalList: Add secondary key `pipeline` #742
- Increase pytest coverage for `common`, `lfp`, and `utils`. #743

### Pipelines

Expand All @@ -31,7 +32,6 @@
- Allow multiple spike waveform features for clusterelss decoding #731
- Reorder notebooks #731


## [0.4.3] (November 7, 2023)

- Migrate `config` helper scripts to Spyglass codebase. #662
Expand Down
39 changes: 38 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ spyglass_cli = "spyglass.cli:cli"
[project.optional-dependencies]
position = ["ffmpeg", "numba>=0.54", "deeplabcut<2.3.0"]
test = [
"docker", # for tests in a container
"pytest", # unit testing
"pytest-cov", # code coverage
"kachery", # database access
Expand Down Expand Up @@ -109,5 +110,41 @@ line-length = 80

[tool.codespell]
skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*'
# Nevers - name in Citation
ignore-words-list = 'nevers'
# Nevers - name in Citation

[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
"-sv",
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
"--cov=spyglass",
"--cov-report=term-missing",
"--no-cov-on-fail",
]
testpaths = ["tests"]
log_level = "INFO"

[tool.coverage.run]
source = ["*/src/spyglass/*"]
omit = [ # which submodules have no tests
"*/__init__.py",
"*/_version.py",
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
"*/linearization/*",
"*/lock/*",
"*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
# "*/utils/*",
]
158 changes: 70 additions & 88 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import ndx_franklab_novela

from spyglass.common.errors import PopulateException
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils.logging import logger
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_nwb_file

schema = dj.schema("common_device")
Expand Down Expand Up @@ -154,25 +154,9 @@ def _add_device(cls, new_device_dict):
all_values = DataAcquisitionDevice.fetch(
"data_acquisition_device_name"
).tolist()
if name not in all_values:
# no entry with the same name exists, prompt user to add a new entry
logger.info(
f"\nData acquisition device '{name}' was not found in the "
f"database. The current values are: {all_values}. "
"Please ensure that the device you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing device in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
to_db = " to the database"
val = input(f"Add data acquisition device '{name}'{to_db}? (y/N)")
if val.lower() in ["y", "yes"]:
cls.insert1(new_device_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add device '{name}'{to_db}."
)
if prompt_insert(name=name, all_values=all_values):
cls.insert1(new_device_dict, skip_duplicates=True)
return

# Check if values provided match the values stored in the database
db_dict = (
Expand Down Expand Up @@ -213,28 +197,11 @@ def _add_system(cls, system):
all_values = DataAcquisitionDeviceSystem.fetch(
"data_acquisition_device_system"
).tolist()
if system not in all_values:
logger.info(
f"\nData acquisition device system '{system}' was not found in"
f" the database. The current values are: {all_values}. "
"Please ensure that the system you want to add does not already"
" exist in the database under a different name or spelling. "
"If you want to use an existing system in the database, "
"please change the corresponding Device object in the NWB file."
" Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add data acquisition device system '{system}'"
+ " to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
else:
raise PopulateException(
"User chose not to add data acquisition device system "
+ f"'{system}' to the database."
)
if prompt_insert(
name=system, all_values=all_values, table_type="system"
):
key = {"data_acquisition_device_system": system}
DataAcquisitionDeviceSystem.insert1(key, skip_duplicates=True)
return system

@classmethod
Expand Down Expand Up @@ -264,30 +231,11 @@ def _add_amplifier(cls, amplifier):
all_values = DataAcquisitionDeviceAmplifier.fetch(
"data_acquisition_device_amplifier"
).tolist()
if amplifier not in all_values:
logger.info(
f"\nData acquisition device amplifier '{amplifier}' was not "
f"found in the database. The current values are: {all_values}. "
"Please ensure that the amplifier you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the database,"
" please change the corresponding Device object in the NWB "
"file. Entering 'N' will raise an exception."
)
val = input(
"Do you want to add data acquisition device amplifier "
+ f"'{amplifier}' to the database? (y/N)"
)
if val.lower() in ["y", "yes"]:
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(
key, skip_duplicates=True
)
else:
raise PopulateException(
"User chose not to add data acquisition device amplifier "
+ f"'{amplifier}' to the database."
)
if prompt_insert(
name=amplifier, all_values=all_values, table_type="amplifier"
):
key = {"data_acquisition_device_amplifier": amplifier}
DataAcquisitionDeviceAmplifier.insert1(key, skip_duplicates=True)
return amplifier


Expand Down Expand Up @@ -576,27 +524,9 @@ def _add_probe_type(cls, new_probe_type_dict):
"""
probe_type = new_probe_type_dict["probe_type"]
all_values = ProbeType.fetch("probe_type").tolist()
if probe_type not in all_values:
logger.info(
f"\nProbe type '{probe_type}' was not found in the database. "
f"The current values are: {all_values}. "
"Please ensure that the probe type you want to add does not "
"already exist in the database under a different name or "
"spelling. If you want to use an existing name in the "
"database, please change the corresponding Probe object in the "
"NWB file. Entering 'N' will raise an exception."
)
val = input(
f"Do you want to add probe type '{probe_type}' to the database?"
+ " (y/N)"
)
if val.lower() in ["y", "yes"]:
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return
raise PopulateException(
f"User chose not to add probe type '{probe_type}' to the "
+ "database."
)
if prompt_insert(probe_type, all_values, table="probe type"):
ProbeType.insert1(new_probe_type_dict, skip_duplicates=True)
return

# else / entry exists: check whether the values provided match the
# values stored in the database
Expand Down Expand Up @@ -738,3 +668,55 @@ def create_from_nwbfile(
cls.Shank.insert1(shank, skip_duplicates=True)
for electrode in elect_dict.values():
cls.Electrode.insert1(electrode, skip_duplicates=True)


# ---------------------------- Helper functions ----------------------------


# Migrated down to reduce redundancy and centralize 'test_mode' check for pytest
def prompt_insert(
CBroz1 marked this conversation as resolved.
Show resolved Hide resolved
name: str,
all_values: list,
table: str = "Data Acquisition Device",
table_type: str = None,
) -> bool:
"""Prompt user to add an item to the database. Return True if yes.

Assume insert during test mode.

Parameters
----------
name : str
The name of the item to add.
all_values : list
List of all values in the database.
table : str, optional
The name of the table to add to, by default Data Acquisition Device
table_type : str, optional
The type of item to add, by default None. Data Acquisition Device X
"""
if name in all_values:
return False

if test_mode:
return True

if table_type:
table_type += " "

logger.info(
f"{table}{table_type} '{name}' was not found in the"
f"database. The current values are: {all_values}.\n"
"Please ensure that the device you want to add does not already"
"exist in the database under a different name or spelling. If you"
"want to use an existing device in the database, please change the"
"corresponding Device object in the NWB file.\nEntering 'N' will "
"raise an exception."
)
msg = f"Do you want to add {table}{table_type} '{name}' to the database?"
if dj.utils.user_choice(msg).lower() in ["y", "yes"]:
return True

raise PopulateException(
f"User chose not to add {table}{table_type} '{name}' to the database."
)
5 changes: 4 additions & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def make(self, key):
key["dio_object_id"] = event_series.object_id
self.insert1(key, skip_duplicates=True)

def plot_all_dio_events(self):
def plot_all_dio_events(self, return_fig=False):
"""Plot all DIO events in the session.

Examples
Expand Down Expand Up @@ -117,3 +117,6 @@ def plot_all_dio_events(self):
plt.suptitle(f"DIO events in {nwb_file_names[0]}")
else:
plt.suptitle(f"DIO events in {', '.join(nwb_file_names)}")

if return_fig:
return plt.gcf()
12 changes: 8 additions & 4 deletions src/spyglass/common/common_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def add_filter(
def _filter_restrict(self, filter_name, fs):
return (
self & {"filter_name": filter_name} & {"filter_sampling_rate": fs}
).fetch1(as_dict=True)
edeno marked this conversation as resolved.
Show resolved Hide resolved
).fetch1()

def plot_magnitude(self, filter_name, fs):
def plot_magnitude(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
w, h = signal.freqz(filter_dict["filter_coeff"], worN=65536)
Expand All @@ -178,11 +178,13 @@ def plot_magnitude(self, filter_name, fs):
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.title("Frequency Response")
plt.xlim(0, np.max(filter_dict["filter_coeffand_edges"] * 2))
plt.xlim(0, np.max(filter_dict["filter_band_edges"] * 2))
plt.ylim(np.min(magnitude), -1 * np.min(magnitude) * 0.1)
plt.grid(True)
if return_fig:
return plt.gcf()

def plot_fir_filter(self, filter_name, fs):
def plot_fir_filter(self, filter_name, fs, return_fig=False):
filter_dict = self._filter_restrict(filter_name, fs)
plt.figure()
plt.clf()
Expand All @@ -191,6 +193,8 @@ def plot_fir_filter(self, filter_name, fs):
plt.ylabel("Magnitude")
plt.title("Filter Taps")
plt.grid(True)
if return_fig:
return plt.gcf()

def filter_delay(self, filter_name, fs):
return self.calc_filter_delay(
Expand Down
8 changes: 6 additions & 2 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):

cls.insert1(epoch_dict, skip_duplicates=True)

def plot_intervals(self, figsize=(20, 5)):
def plot_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=figsize)
interval_count = 0
Expand All @@ -84,8 +84,10 @@ def plot_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(interval_list.interval_list_name)
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig

def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
def plot_epoch_pos_raw_intervals(self, figsize=(20, 5), return_fig=False):
interval_list = pd.DataFrame(self)
fig, ax = plt.subplots(figsize=(30, 3))

Expand Down Expand Up @@ -145,6 +147,8 @@ def plot_epoch_pos_raw_intervals(self, figsize=(20, 5)):
ax.set_yticklabels(["pos valid times", "raw data valid times", "epoch"])
ax.set_xlabel("Time [s]")
ax.grid(True)
if return_fig:
return fig


def intervals_by_length(interval_list, min_length=0.0, max_length=1e10):
Expand Down
Loading