Skip to content

Commit

Permalink
Merge pull request #106 from ttngu207/main
Browse files Browse the repository at this point in the history
Add pytest
  • Loading branch information
MilagrosMarin authored Mar 22, 2024
2 parents 499d85e + 47bbbbc commit 16c845e
Show file tree
Hide file tree
Showing 7 changed files with 758 additions and 531 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ RUN \
# pipeline dependencies
apt-get update && \
apt-get install -y gcc ffmpeg graphviz && \
pip install --no-cache-dir -e /tmp/element-deeplabcut[elements,dlc_default] && \
pip install --no-cache-dir -e /tmp/element-deeplabcut[elements,dlc_default,tests] && \
# clean up
rm -rf /tmp/element-deeplabcut/ && \
apt-get clean
Expand Down
1,051 changes: 528 additions & 523 deletions notebooks/tutorial.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@
"element-session @ git+https://github.com/datajoint/element-session.git",
"element-interface @ git+https://github.com/datajoint/element-interface.git",
],
"tests": ["pytest", "pytest-cov", "shutils"],
},
)
Empty file added tests/__init__.py
Empty file.
166 changes: 166 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import os
from pathlib import Path
import datajoint as dj
import pytest


logger = dj.logger
_tear_down = True

# ---------------------- FIXTURES ----------------------


@pytest.fixture(autouse=True, scope="session")
def dj_config():
"""If dj_local_config exists, load"""
if Path("./dj_local_conf.json").exists():
dj.config.load("./dj_local_conf.json")
dj.config.update(
{
"safemode": False,
"database.host": os.environ.get("DJ_HOST") or dj.config["database.host"],
"database.password": os.environ.get("DJ_PASS")
or dj.config["database.password"],
"database.user": os.environ.get("DJ_USER") or dj.config["database.user"],
}
)
os.environ["DATABASE_PREFIX"] = "test_"
return


@pytest.fixture(autouse=True, scope="session")
def pipeline():
from . import tutorial_pipeline as pipeline

yield {
"lab": pipeline.lab,
"subject": pipeline.subject,
"session": pipeline.session,
"model": pipeline.model,
"train": pipeline.train,
"Device": pipeline.Device
}

if _tear_down:
pipeline.model.schema.drop()
pipeline.train.schema.drop()
pipeline.session.schema.drop()
pipeline.subject.schema.drop()
pipeline.lab.schema.drop()


@pytest.fixture(scope="session")
def insert_upstreams(pipeline):
subject = pipeline["subject"]
session = pipeline["session"]
model = pipeline["model"]

subject.Subject.insert1(
dict(
subject="subject6",
sex="F",
subject_birth_date="2020-01-01",
subject_description="hneih_E105",
),
skip_duplicates=True,
)

session_keys = [
dict(subject="subject6", session_datetime="2021-06-02 14:04:22"),
dict(subject="subject6", session_datetime="2021-06-03 14:43:10"),
]

session.Session.insert(session_keys, skip_duplicates=True)

recording_key = {
"subject": "subject6",
"session_datetime": "2021-06-02 14:04:22",
"recording_id": "1",
}
model.VideoRecording.insert1(
{**recording_key, "device": "Camera1"}, skip_duplicates=True
)

video_files = [
"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4"
]

model.VideoRecording.File.insert(
[{**recording_key, "file_id": v_idx, "file_path": Path(f)}
for v_idx, f in enumerate(video_files)], skip_duplicates=True
)

yield

if _tear_down:
subject.Subject.delete()


@pytest.fixture(scope="session")
def recording_info(pipeline, insert_upstreams):
model = pipeline["model"]
model.RecordingInfo.populate()

yield

if _tear_down:
model.RecordingInfo.delete()


@pytest.fixture(scope="session")
def insert_dlc_model(pipeline):
model = pipeline["model"]

if not model.Model & {"model_name": "from_top_tracking_model_test"}:
config_file_rel = "from_top_tracking-DataJoint-2023-10-11/config.yaml"

model.Model.insert_new_model(
model_name="from_top_tracking_model_test",
dlc_config=config_file_rel,
shuffle=1,
trainingsetindex=0,
model_description="Model in example data: from_top_tracking model",
prompt=False
)

yield

if _tear_down:
model.Model.delete()


@pytest.fixture(scope="session")
def insert_pose_estimation_task(pipeline, recording_info, insert_dlc_model):
model = pipeline["model"]

recording_key = {
"subject": "subject6",
"session_datetime": "2021-06-02 14:04:22",
"recording_id": "1",
}
task_key = {**recording_key, "model_name": "from_top_tracking_model_test"}

model.PoseEstimationTask.insert1(
{
**task_key,
"task_mode": "load",
"pose_estimation_output_dir": "from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters",
}
)

yield

if _tear_down:
model.PoseEstimationTask.delete()


@pytest.fixture(scope="session")
def pose_estimation(pipeline, insert_pose_estimation_task):
model = pipeline["model"]

model.PoseEstimation.populate()

yield

if _tear_down:
model.PoseEstimation.delete()
56 changes: 56 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import datetime


def test_generate_pipeline(pipeline):
subject = pipeline["subject"]
session = pipeline["session"]
train = pipeline["train"]
model = pipeline["model"]
Device = pipeline["Device"]

# Test connection from Subject to Session
assert subject.Subject.full_table_name in session.Session.parents()

# Test connection from Session and Equipment to Scan
assert session.Session.full_table_name in model.VideoRecording.parents()
assert Device.full_table_name in model.VideoRecording.parents()

assert "snapshotindex" in model.Model.heading.secondary_attributes
assert "trainingsetindex" in model.Model.heading.secondary_attributes
assert "x_pos" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes
assert "y_pos" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes
assert "likelihood" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes

assert len(train.schema.list_tables()) == 5


def test_recording_info(pipeline, recording_info):
model = pipeline["model"]
expected_rec_info = {'subject': 'subject6',
'session_datetime': datetime.datetime(2021, 6, 2, 14, 4, 22),
'recording_id': 1,
'px_height': 500,
'px_width': 500,
'nframes': 60000,
'fps': 60,
'recording_datetime': None,
'recording_duration': 1000.0}

rec_info = model.RecordingInfo.fetch1()

assert rec_info == expected_rec_info


def test_pose_estimation(pipeline, pose_estimation):
model = pipeline["model"]

body_parts = model.PoseEstimation.BodyPartPosition.fetch("body_part")

assert set(body_parts) == {"head", "tailbase"}

head_x = (model.PoseEstimation.BodyPartPosition & {"body_part": "head"}).fetch1("x_pos")
tail_y = (model.PoseEstimation.BodyPartPosition & {"body_part": "tailbase"}).fetch1("y_pos")

assert len(head_x) == len(tail_y)
assert (round(head_x.std())) == 129
assert (round(tail_y.std())) == 133
13 changes: 6 additions & 7 deletions notebooks/tutorial_pipeline.py → tests/tutorial_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ def get_dlc_processed_data_dir() -> str:


__all__ = [
"Subject",
"Source",
"Lab",
"Protocol",
"User",
"Project",
"Session",
"lab",
"subject",
"session",
"train",
"model",
"Device"
]

# Activate schemas -------------
Expand Down

0 comments on commit 16c845e

Please sign in to comment.