-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #106 from ttngu207/main
Add pytest
- Loading branch information
Showing
7 changed files
with
758 additions
and
531 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import os | ||
from pathlib import Path | ||
import datajoint as dj | ||
import pytest | ||
|
||
|
||
logger = dj.logger | ||
_tear_down = True | ||
|
||
# ---------------------- FIXTURES ---------------------- | ||
|
||
|
||
@pytest.fixture(autouse=True, scope="session") | ||
def dj_config(): | ||
"""If dj_local_config exists, load""" | ||
if Path("./dj_local_conf.json").exists(): | ||
dj.config.load("./dj_local_conf.json") | ||
dj.config.update( | ||
{ | ||
"safemode": False, | ||
"database.host": os.environ.get("DJ_HOST") or dj.config["database.host"], | ||
"database.password": os.environ.get("DJ_PASS") | ||
or dj.config["database.password"], | ||
"database.user": os.environ.get("DJ_USER") or dj.config["database.user"], | ||
} | ||
) | ||
os.environ["DATABASE_PREFIX"] = "test_" | ||
return | ||
|
||
|
||
@pytest.fixture(autouse=True, scope="session") | ||
def pipeline(): | ||
from . import tutorial_pipeline as pipeline | ||
|
||
yield { | ||
"lab": pipeline.lab, | ||
"subject": pipeline.subject, | ||
"session": pipeline.session, | ||
"model": pipeline.model, | ||
"train": pipeline.train, | ||
"Device": pipeline.Device | ||
} | ||
|
||
if _tear_down: | ||
pipeline.model.schema.drop() | ||
pipeline.train.schema.drop() | ||
pipeline.session.schema.drop() | ||
pipeline.subject.schema.drop() | ||
pipeline.lab.schema.drop() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def insert_upstreams(pipeline): | ||
subject = pipeline["subject"] | ||
session = pipeline["session"] | ||
model = pipeline["model"] | ||
|
||
subject.Subject.insert1( | ||
dict( | ||
subject="subject6", | ||
sex="F", | ||
subject_birth_date="2020-01-01", | ||
subject_description="hneih_E105", | ||
), | ||
skip_duplicates=True, | ||
) | ||
|
||
session_keys = [ | ||
dict(subject="subject6", session_datetime="2021-06-02 14:04:22"), | ||
dict(subject="subject6", session_datetime="2021-06-03 14:43:10"), | ||
] | ||
|
||
session.Session.insert(session_keys, skip_duplicates=True) | ||
|
||
recording_key = { | ||
"subject": "subject6", | ||
"session_datetime": "2021-06-02 14:04:22", | ||
"recording_id": "1", | ||
} | ||
model.VideoRecording.insert1( | ||
{**recording_key, "device": "Camera1"}, skip_duplicates=True | ||
) | ||
|
||
video_files = [ | ||
"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4" | ||
] | ||
|
||
model.VideoRecording.File.insert( | ||
[{**recording_key, "file_id": v_idx, "file_path": Path(f)} | ||
for v_idx, f in enumerate(video_files)], skip_duplicates=True | ||
) | ||
|
||
yield | ||
|
||
if _tear_down: | ||
subject.Subject.delete() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def recording_info(pipeline, insert_upstreams): | ||
model = pipeline["model"] | ||
model.RecordingInfo.populate() | ||
|
||
yield | ||
|
||
if _tear_down: | ||
model.RecordingInfo.delete() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def insert_dlc_model(pipeline): | ||
model = pipeline["model"] | ||
|
||
if not model.Model & {"model_name": "from_top_tracking_model_test"}: | ||
config_file_rel = "from_top_tracking-DataJoint-2023-10-11/config.yaml" | ||
|
||
model.Model.insert_new_model( | ||
model_name="from_top_tracking_model_test", | ||
dlc_config=config_file_rel, | ||
shuffle=1, | ||
trainingsetindex=0, | ||
model_description="Model in example data: from_top_tracking model", | ||
prompt=False | ||
) | ||
|
||
yield | ||
|
||
if _tear_down: | ||
model.Model.delete() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def insert_pose_estimation_task(pipeline, recording_info, insert_dlc_model): | ||
model = pipeline["model"] | ||
|
||
recording_key = { | ||
"subject": "subject6", | ||
"session_datetime": "2021-06-02 14:04:22", | ||
"recording_id": "1", | ||
} | ||
task_key = {**recording_key, "model_name": "from_top_tracking_model_test"} | ||
|
||
model.PoseEstimationTask.insert1( | ||
{ | ||
**task_key, | ||
"task_mode": "load", | ||
"pose_estimation_output_dir": "from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters", | ||
} | ||
) | ||
|
||
yield | ||
|
||
if _tear_down: | ||
model.PoseEstimationTask.delete() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def pose_estimation(pipeline, insert_pose_estimation_task): | ||
model = pipeline["model"] | ||
|
||
model.PoseEstimation.populate() | ||
|
||
yield | ||
|
||
if _tear_down: | ||
model.PoseEstimation.delete() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import datetime | ||
|
||
|
||
def test_generate_pipeline(pipeline): | ||
subject = pipeline["subject"] | ||
session = pipeline["session"] | ||
train = pipeline["train"] | ||
model = pipeline["model"] | ||
Device = pipeline["Device"] | ||
|
||
# Test connection from Subject to Session | ||
assert subject.Subject.full_table_name in session.Session.parents() | ||
|
||
# Test connection from Session and Equipment to Scan | ||
assert session.Session.full_table_name in model.VideoRecording.parents() | ||
assert Device.full_table_name in model.VideoRecording.parents() | ||
|
||
assert "snapshotindex" in model.Model.heading.secondary_attributes | ||
assert "trainingsetindex" in model.Model.heading.secondary_attributes | ||
assert "x_pos" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes | ||
assert "y_pos" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes | ||
assert "likelihood" in model.PoseEstimation.BodyPartPosition.heading.secondary_attributes | ||
|
||
assert len(train.schema.list_tables()) == 5 | ||
|
||
|
||
def test_recording_info(pipeline, recording_info): | ||
model = pipeline["model"] | ||
expected_rec_info = {'subject': 'subject6', | ||
'session_datetime': datetime.datetime(2021, 6, 2, 14, 4, 22), | ||
'recording_id': 1, | ||
'px_height': 500, | ||
'px_width': 500, | ||
'nframes': 60000, | ||
'fps': 60, | ||
'recording_datetime': None, | ||
'recording_duration': 1000.0} | ||
|
||
rec_info = model.RecordingInfo.fetch1() | ||
|
||
assert rec_info == expected_rec_info | ||
|
||
|
||
def test_pose_estimation(pipeline, pose_estimation): | ||
model = pipeline["model"] | ||
|
||
body_parts = model.PoseEstimation.BodyPartPosition.fetch("body_part") | ||
|
||
assert set(body_parts) == {"head", "tailbase"} | ||
|
||
head_x = (model.PoseEstimation.BodyPartPosition & {"body_part": "head"}).fetch1("x_pos") | ||
tail_y = (model.PoseEstimation.BodyPartPosition & {"body_part": "tailbase"}).fetch1("y_pos") | ||
|
||
assert len(head_x) == len(tail_y) | ||
assert (round(head_x.std())) == 129 | ||
assert (round(tail_y.std())) == 133 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters