Skip to content

Commit

Permalink
WIP: Improve coverage behav, dio
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Jan 8, 2024
1 parent 8e15822 commit 7ff01ad
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 12 deletions.
5 changes: 4 additions & 1 deletion src/spyglass/common/common_dio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def make(self, key):
key["dio_object_id"] = event_series.object_id
self.insert1(key, skip_duplicates=True)

def plot_all_dio_events(self):
def plot_all_dio_events(self, return_fig=False):
"""Plot all DIO events in the session.
Examples
Expand Down Expand Up @@ -117,3 +117,6 @@ def plot_all_dio_events(self):
plt.suptitle(f"DIO events in {nwb_file_names[0]}")
else:
plt.suptitle(f"DIO events in {', '.join(nwb_file_names)}")

if return_fig:
return plt.gcf()
4 changes: 2 additions & 2 deletions src/spyglass/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, base_dir: str = None, **kwargs):
self.supplied_base_dir = base_dir
self._config = dict()
self.config_defaults = dict(prepopulate=True)
self._debug_mode = False
self._test_mode = False
self._debug_mode = kwargs.get("debug_mode", False)
self._test_mode = kwargs.get("test_mode", False)
self._dlc_base = None

self.relative_dirs = {
Expand Down
15 changes: 15 additions & 0 deletions tests/common/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,23 @@ def mini_pos_series(mini_pos):
yield next(iter(mini_pos))


@pytest.fixture(scope="session")
def mini_pos_interval_dict(common):
yield {"interval_list_name": common.PositionSource.get_pos_interval_name(0)}


@pytest.fixture(scope="session")
def mini_pos_tbl(common, mini_pos_series):
yield common.PositionSource.SpatialSeries * common.RawPosition.PosObject & {
"name": mini_pos_series
}


@pytest.fixture(scope="session")
def pos_src(common):
yield common.PositionSource()


@pytest.fixture(scope="session")
def pos_interval_01(pos_src):
yield [pos_src.get_pos_interval_name(x) for x in range(1)]
72 changes: 72 additions & 0 deletions tests/common/test_behav.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from numpy import array_equal
from pandas import DataFrame


def test_invalid_interval(pos_src):
"""Test invalid interval"""
with pytest.raises(ValueError):
pos_src.get_pos_interval_name("invalid_interval")


def test_invalid_epoch_num(common):
"""Test invalid epoch num"""
with pytest.raises(ValueError):
common.PositionSource.get_epoch_num("invalid_epoch_num")


def test_raw_position_fetchnwb(common, mini_pos, mini_pos_interval_dict):
"""Test RawPosition fetch nwb"""
fetched = DataFrame(
(common.RawPosition & mini_pos_interval_dict)
.fetch_nwb()[0]["raw_position"]
.data
)
raw = DataFrame(mini_pos["led_0_series_0"].data)
# compare with mini_pos
assert fetched.equals(raw), "RawPosition fetch_nwb failed"


@pytest.mark.skip(reason="No video files in mini")
def test_videofile_no_transaction(common, mini_restr):
"""Test no transaction"""
common.VideoFile()._no_transaction_make(mini_restr)


@pytest.mark.skip(reason="No video files in mini")
def test_videofile_update_entries(common):
"""Test update entries"""
common.VideoFile().update_entries()


@pytest.mark.skip(reason="No video files in mini")
def test_videofile_getabspath(common, mini_restr):
"""Test get absolute path"""
common.VideoFile().getabspath(mini_restr)


def test_posinterval_no_transaction(verbose_context, common, mini_restr):
"""Test no transaction"""
before = common.PositionIntervalMap().fetch()
with verbose_context:
common.PositionIntervalMap()._no_transaction_make(mini_restr)
after = common.PositionIntervalMap().fetch()
assert array_equal(
before, after
), "PositionIntervalMap no_transaction had unexpected effect"


def test_get_pos_interval_name(pos_src, mini_copy_name, pos_interval_01):
"""Test get pos interval name"""
names = [f"pos {x} valid times" for x in range(1)]
assert pos_interval_01 == names, "get_pos_interval_name failed"


def test_convert_epoch(common, pos_interval_01):
this_key = (common.IntervalList & {"interval_list_name": "01_s1"}).fetch1()
ret = common.common_behav.convert_epoch_interval_name_to_position_interval_name(
this_key
)
assert (
ret == pos_interval_01[0]
), "convert_epoch_interval_name_to_position_interval_name failed"
40 changes: 40 additions & 0 deletions tests/common/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from numpy import array_equal


def test_invalid_device(common, populate_exception):
device_dict = common.DataAcquisitionDevice.fetch(as_dict=True)[0]
device_dict["other"] = "invalid"
with pytest.raises(populate_exception):
common.DataAcquisitionDevice._add_device(device_dict)


def test_spikegadets_system_alias(mini_insert, common):
assert (
common.DataAcquisitionDevice()._add_system("MCU") == "SpikeGadgets"
), "SpikeGadgets MCU alias not found"


def test_invalid_probe(common, populate_exception):
probe_dict = common.ProbeType.fetch(as_dict=True)[0]
probe_dict["other"] = "invalid"
with pytest.raises(populate_exception):
common.Probe._add_probe_type(probe_dict)


def test_create_probe(common, mini_devices, mini_path, mini_copy_name):
probe_id = common.Probe.fetch("KEY", as_dict=True)[0]
probe_type = common.ProbeType.fetch("KEY", as_dict=True)[0]
before = common.Probe.fetch()
common.Probe.create_from_nwbfile(
nwb_file_name=mini_copy_name.split("/")[-1],
nwb_device_name="probe 0",
contact_side_numbering=False,
**probe_id,
**probe_type,
)
after = common.Probe.fetch()
# Because already inserted, expect no change
assert array_equal(
before, after
), "Probe create_from_nwbfile had unexpected effect"
31 changes: 31 additions & 0 deletions tests/common/test_dio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from numpy import allclose, array


@pytest.fixture(scope="session")
def dio_events(common):
yield common.common_dio.DIOEvents


@pytest.fixture(scope="session")
def dio_fig(mini_insert, dio_events, mini_restr):
yield (dio_events & mini_restr).plot_all_dio_events(return_fig=True)


def test_plot_dio_axes(dio_fig, dio_events):
"""Check that all events are plotted."""
events_fig = set(x.yaxis.get_label().get_text() for x in dio_fig.get_axes())
events_fetch = set(dio_events.fetch("dio_event_name"))
assert events_fig == events_fetch, "Mismatch in events plotted."


def test_plot_dio_data(common, dio_fig):
"""Hash summary of figure object."""
data_fig = dio_fig.get_axes()[0].lines[0].get_xdata()
data_block = (
common.IntervalList & 'interval_list_name LIKE "raw%"'
).fetch1("valid_times")
data_fetch = array((data_block[0][0], data_block[-1][1]))
assert allclose(
data_fig, data_fetch, atol=1e-8
), "Mismatch in data plotted."
4 changes: 2 additions & 2 deletions tests/common/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ def test_insert_pos(
assert data_obj_id == raw_obj_id, "PosObject insertion error"


def test_fetch_pos(
def test_fetch_posobj(
mini_insert, common, mini_pos, mini_pos_series, mini_pos_tbl
):
pos_key = (
common.PositionSource.SpatialSeries & mini_pos_tbl.fetch("KEY")
).fetch(as_dict=True)[0]
pos_df = (common.RawPosition.PosObject & pos_key).fetch1_dataframe()
pos_df = (common.RawPosition & pos_key).fetch1_dataframe().iloc[:, 0:2]

series = mini_pos[mini_pos_series]
raw_df = DataFrame(
Expand Down
35 changes: 29 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import datajoint as dj
import pynwb
import pytest
from datajoint.logging import logger
from datajoint.logging import logger as dj_logger

from .container import DockerMySQLManager

Expand Down Expand Up @@ -139,6 +139,13 @@ def mini_path(raw_dir):
yield raw_dir / "test.nwb"


@pytest.fixture(scope="session")
def mini_copy_name(mini_path):
from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename # noqa: E402

yield get_nwb_copy_filename(mini_path)


@pytest.fixture(scope="session")
def mini_download():
# test_path = (
Expand Down Expand Up @@ -181,21 +188,23 @@ def mini_closed(mini_path):
yield nwbfile


@pytest.fixture(scope="session")
@pytest.fixture(autouse=True, scope="session")
def mini_insert(mini_path, teardown, server, dj_conn):
from spyglass.common import Nwbfile, Session # noqa: E402
from spyglass.data_import import insert_sessions # noqa: E402
from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402

dj_logger.info("Inserting test data.")

if len(Nwbfile()) > 0:
Nwbfile().delete(safemode=False)

if server.connected:
insert_sessions(mini_path.name)
else:
logger.error("No server connection.")
dj_logger.error("No server connection.")
if len(Session()) == 0:
logger.error("No sessions inserted.")
dj_logger.error("No sessions inserted.")

yield

Expand Down Expand Up @@ -230,18 +239,32 @@ def settings(dj_conn):
yield settings


@pytest.fixture(scope="session")
def populate_exception():
from spyglass.common.errors import PopulateException

yield PopulateException


# ------------------ GENERAL FUNCTION ------------------


class QuietStdOut:
"""If quiet_spy, used to quiet prints, teardowns and table.delete prints"""

def __init__(self):
from spyglass.utils import logger as spyglass_logger

self.spy_logger = spyglass_logger
self.previous_level = None

def __enter__(self):
logger.setLevel("CRITICAL")
self.previous_level = self.spy_logger.getEffectiveLevel()
self.spy_logger.setLevel("CRITICAL")
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")

def __exit__(self, exc_type, exc_val, exc_tb):
logger.setLevel("INFO")
self.spy_logger.setLevel(self.previous_level)
sys.stdout.close()
sys.stdout = self._original_stdout
2 changes: 1 addition & 1 deletion tests/data_import/test_insert_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_open_path(mini_path, mini_open):


def test_copy_link(mini_path, settings, mini_closed, copy_nwb_link_raw_ephys):
"""Test readabilty after moving the linking raw file, breaking link"""
"""Test readability after moving the linking raw file, breaking link"""
new_path = Path(settings.raw_dir) / "no_ephys.nwb"
new_moved = Path(settings.temp_dir) / "no_ephys_moved.nwb"

Expand Down

0 comments on commit 7ff01ad

Please sign in to comment.