From 37b8fb3a32f99be6457f41b8e4a3696b2d954156 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Wed, 13 Dec 2023 12:40:52 -0700 Subject: [PATCH 1/9] cp to tests --- tests/test_relation.py | 311 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 tests/test_relation.py diff --git a/tests/test_relation.py b/tests/test_relation.py new file mode 100644 index 000000000..a5f5da3af --- /dev/null +++ b/tests/test_relation.py @@ -0,0 +1,311 @@ +from inspect import getmembers +import re +import pandas +import numpy as np +from nose.tools import ( + assert_equal, + assert_not_equal, + assert_true, + assert_list_equal, + raises, +) +import datajoint as dj +from datajoint.table import Table +from unittest.mock import patch + +from . import schema + + +def relation_selector(attr): + try: + return issubclass(attr, Table) + except TypeError: + return False + + +class TestRelation: + """ + Test base relations: insert, delete + """ + + @classmethod + def setup_class(cls): + cls.test = schema.TTest() + cls.test_extra = schema.TTestExtra() + cls.test_no_extra = schema.TTestNoExtra() + cls.user = schema.User() + cls.subject = schema.Subject() + cls.experiment = schema.Experiment() + cls.trial = schema.Trial() + cls.ephys = schema.Ephys() + cls.channel = schema.Ephys.Channel() + cls.img = schema.Image() + cls.trash = schema.UberTrash() + + def test_contents(self): + """ + test the ability of tables to self-populate using the contents property + """ + # test contents + assert_true(self.user) + assert_true(len(self.user) == len(self.user.contents)) + u = self.user.fetch(order_by=["username"]) + assert_list_equal( + list(u["username"]), sorted([s[0] for s in self.user.contents]) + ) + + # test prepare + assert_true(self.subject) + assert_true(len(self.subject) == len(self.subject.contents)) + u = self.subject.fetch(order_by=["subject_id"]) + assert_list_equal( + list(u["subject_id"]), sorted([s[0] for s in self.subject.contents]) + ) + + @raises(dj.DataJointError) + def test_misnamed_attribute1(self): + self.user.insert([dict(username="Bob"), dict(user="Alice")]) + + @raises(KeyError) + def test_misnamed_attribute2(self): + self.user.insert1(dict(user="Bob")) + + @raises(KeyError) + def test_extra_attribute1(self): + self.user.insert1(dict(username="Robert", spouse="Alice")) + + def test_extra_attribute2(self): + self.user.insert1( + dict(username="Robert", spouse="Alice"), ignore_extra_fields=True + ) + + @raises(NotImplementedError) + def test_missing_definition(self): + @schema.schema + class MissingDefinition(dj.Manual): + definitions = """ # misspelled definition + id : int + --- + comment : varchar(16) # otherwise everything's normal + """ + + @raises(dj.DataJointError) + def test_empty_insert1(self): + self.user.insert1(()) + + @raises(dj.DataJointError) + def test_empty_insert(self): + self.user.insert([()]) + + @raises(dj.DataJointError) + def test_wrong_arguments_insert(self): + self.user.insert1(("First", "Second")) + + @raises(dj.DataJointError) + def test_wrong_insert_type(self): + self.user.insert1(3) + + def test_insert_select(self): + schema.TTest2.delete() + schema.TTest2.insert(schema.TTest) + assert_equal(len(schema.TTest2()), len(schema.TTest())) + + original_length = len(self.subject) + elements = self.subject.proj(..., s="subject_id") + elements = elements.proj( + "real_id", + "date_of_birth", + "subject_notes", + subject_id="s+1000", + species='"human"', + ) + self.subject.insert(elements, ignore_extra_fields=True) + assert_equal(len(self.subject), 2 * original_length) + + def test_insert_pandas_roundtrip(self): + """ensure fetched frames can be inserted""" + schema.TTest2.delete() + n = len(schema.TTest()) + assert_true(n > 0) + df = schema.TTest.fetch(format="frame") + assert_true(isinstance(df, pandas.DataFrame)) + assert_equal(len(df), n) + schema.TTest2.insert(df) + assert_equal(len(schema.TTest2()), n) + + def test_insert_pandas_userframe(self): + """ + ensure simple user-created frames (1 field, non-custom index) + can be inserted without extra index adjustment + """ + schema.TTest2.delete() + n = len(schema.TTest()) + assert_true(n > 0) + df = pandas.DataFrame(schema.TTest.fetch()) + assert_true(isinstance(df, pandas.DataFrame)) + assert_equal(len(df), n) + schema.TTest2.insert(df) + assert_equal(len(schema.TTest2()), n) + + @raises(dj.DataJointError) + def test_insert_select_ignore_extra_fields0(self): + """need ignore extra fields for insert select""" + self.test_extra.insert1((self.test.fetch("key").max() + 1, 0, 0)) + self.test.insert(self.test_extra) + + def test_insert_select_ignore_extra_fields1(self): + """make sure extra fields works in insert select""" + self.test_extra.delete() + keyno = self.test.fetch("key").max() + 1 + self.test_extra.insert1((keyno, 0, 0)) + self.test.insert(self.test_extra, ignore_extra_fields=True) + assert keyno in self.test.fetch("key") + + def test_insert_select_ignore_extra_fields2(self): + """make sure insert select still works when ignoring extra fields when there are none""" + self.test_no_extra.delete() + self.test_no_extra.insert(self.test, ignore_extra_fields=True) + + def test_insert_select_ignore_extra_fields3(self): + """make sure insert select works for from query result""" + self.test_no_extra.delete() + keystr = str(self.test_extra.fetch("key").max()) + self.test_no_extra.insert( + (self.test_extra & "`key`=" + keystr), ignore_extra_fields=True + ) + + def test_skip_duplicates(self): + """test that skip_duplicates works when inserting from another table""" + self.test_no_extra.delete() + self.test_no_extra.insert( + self.test, ignore_extra_fields=True, skip_duplicates=True + ) + self.test_no_extra.insert( + self.test, ignore_extra_fields=True, skip_duplicates=True + ) + + def test_replace(self): + """ + Test replacing or ignoring duplicate entries + """ + key = dict(subject_id=7) + date = "2015-01-01" + self.subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) + assert_equal( + date, str((self.subject & key).fetch1("date_of_birth")), "incorrect insert" + ) + date = "2015-01-02" + self.subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), + skip_duplicates=True, + ) + assert_not_equal( + date, + str((self.subject & key).fetch1("date_of_birth")), + "inappropriate replace", + ) + self.subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True + ) + assert_equal( + date, str((self.subject & key).fetch1("date_of_birth")), "replace failed" + ) + + def test_delete_quick(self): + """Tests quick deletion""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=self.subject.heading.as_dtype, + ) + self.subject.insert(tmp) + s = self.subject & ( + "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) + ) + assert_true(len(s) == 2, "insert did not work.") + s.delete_quick() + assert_true(len(s) == 0, "delete did not work.") + + def test_skip_duplicate(self): + """Tests if duplicates are properly skipped.""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=self.subject.heading.as_dtype, + ) + self.subject.insert(tmp) + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=self.subject.heading.as_dtype, + ) + self.subject.insert(tmp, skip_duplicates=True) + + @raises(dj.errors.DuplicateError) + def test_not_skip_duplicate(self): + """Tests if duplicates are not skipped.""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=self.subject.heading.as_dtype, + ) + self.subject.insert(tmp, skip_duplicates=False) + + @raises(dj.errors.MissingAttributeError) + def test_no_error_suppression(self): + """skip_duplicates=True should not suppress other errors""" + self.test.insert([dict(key=100)], skip_duplicates=True) + + def test_blob_insert(self): + """Tests inserting and retrieving blobs.""" + X = np.random.randn(20, 10) + self.img.insert1((1, X)) + Y = self.img.fetch()[0]["img"] + assert_true(np.all(X == Y), "Inserted and retrieved image are not identical") + + def test_drop(self): + """Tests dropping tables""" + dj.config["safemode"] = True + with patch.object(dj.utils, "input", create=True, return_value="yes"): + self.trash.drop() + try: + self.trash.fetch() + raise Exception("Fetched after table dropped.") + except dj.DataJointError: + pass + finally: + dj.config["safemode"] = False + + def test_table_regexp(self): + """Test whether table names are matched by regular expressions""" + tiers = [dj.Imported, dj.Manual, dj.Lookup, dj.Computed] + for name, rel in getmembers(schema, relation_selector): + assert_true( + re.match(rel.tier_regexp, rel.table_name), + "Regular expression does not match for {name}".format(name=name), + ) + for tier in tiers: + assert_true( + issubclass(rel, tier) + or not re.match(tier.tier_regexp, rel.table_name), + "Regular expression matches for {name} but should not".format( + name=name + ), + ) + + def test_table_size(self): + """test getting the size of the table and its indices in bytes""" + number_of_bytes = self.experiment.size_on_disk + assert_true(isinstance(number_of_bytes, int) and number_of_bytes > 100) + + def test_repr_html(self): + assert_true(self.ephys._repr_html_().strip().startswith(" Date: Wed, 13 Dec 2023 12:42:06 -0700 Subject: [PATCH 2/9] nose2pytest test_reconnection --- tests/test_relation.py | 58 ++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index a5f5da3af..0a6e2f436 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -47,20 +47,18 @@ def test_contents(self): test the ability of tables to self-populate using the contents property """ # test contents - assert_true(self.user) - assert_true(len(self.user) == len(self.user.contents)) + assert self.user + assert len(self.user) == len(self.user.contents) u = self.user.fetch(order_by=["username"]) - assert_list_equal( - list(u["username"]), sorted([s[0] for s in self.user.contents]) - ) + assert ( + list(u["username"]) == sorted([s[0] for s in self.user.contents])) # test prepare - assert_true(self.subject) - assert_true(len(self.subject) == len(self.subject.contents)) + assert self.subject + assert len(self.subject) == len(self.subject.contents) u = self.subject.fetch(order_by=["subject_id"]) - assert_list_equal( - list(u["subject_id"]), sorted([s[0] for s in self.subject.contents]) - ) + assert ( + list(u["subject_id"]) == sorted([s[0] for s in self.subject.contents])) @raises(dj.DataJointError) def test_misnamed_attribute1(self): @@ -108,7 +106,7 @@ def test_wrong_insert_type(self): def test_insert_select(self): schema.TTest2.delete() schema.TTest2.insert(schema.TTest) - assert_equal(len(schema.TTest2()), len(schema.TTest())) + assert len(schema.TTest2()) == len(schema.TTest()) original_length = len(self.subject) elements = self.subject.proj(..., s="subject_id") @@ -120,18 +118,18 @@ def test_insert_select(self): species='"human"', ) self.subject.insert(elements, ignore_extra_fields=True) - assert_equal(len(self.subject), 2 * original_length) + assert len(self.subject) == 2 * original_length def test_insert_pandas_roundtrip(self): """ensure fetched frames can be inserted""" schema.TTest2.delete() n = len(schema.TTest()) - assert_true(n > 0) + assert n > 0 df = schema.TTest.fetch(format="frame") - assert_true(isinstance(df, pandas.DataFrame)) - assert_equal(len(df), n) + assert isinstance(df, pandas.DataFrame) + assert len(df) == n schema.TTest2.insert(df) - assert_equal(len(schema.TTest2()), n) + assert len(schema.TTest2()) == n def test_insert_pandas_userframe(self): """ @@ -140,12 +138,12 @@ def test_insert_pandas_userframe(self): """ schema.TTest2.delete() n = len(schema.TTest()) - assert_true(n > 0) + assert n > 0 df = pandas.DataFrame(schema.TTest.fetch()) - assert_true(isinstance(df, pandas.DataFrame)) - assert_equal(len(df), n) + assert isinstance(df, pandas.DataFrame) + assert len(df) == n schema.TTest2.insert(df) - assert_equal(len(schema.TTest2()), n) + assert len(schema.TTest2()) == n @raises(dj.DataJointError) def test_insert_select_ignore_extra_fields0(self): @@ -191,9 +189,8 @@ def test_replace(self): key = dict(subject_id=7) date = "2015-01-01" self.subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) - assert_equal( - date, str((self.subject & key).fetch1("date_of_birth")), "incorrect insert" - ) + assert ( + date == str((self.subject & key).fetch1("date_of_birth"))), "incorrect insert" date = "2015-01-02" self.subject.insert1( dict(key, real_id=7, date_of_birth=date, subject_notes=""), @@ -207,9 +204,8 @@ def test_replace(self): self.subject.insert1( dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True ) - assert_equal( - date, str((self.subject & key).fetch1("date_of_birth")), "replace failed" - ) + assert ( + date == str((self.subject & key).fetch1("date_of_birth"))), "replace failed" def test_delete_quick(self): """Tests quick deletion""" @@ -224,9 +220,9 @@ def test_delete_quick(self): s = self.subject & ( "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) ) - assert_true(len(s) == 2, "insert did not work.") + assert len(s) == 2, "insert did not work." s.delete_quick() - assert_true(len(s) == 0, "delete did not work.") + assert len(s) == 0, "delete did not work." def test_skip_duplicate(self): """Tests if duplicates are properly skipped.""" @@ -270,7 +266,7 @@ def test_blob_insert(self): X = np.random.randn(20, 10) self.img.insert1((1, X)) Y = self.img.fetch()[0]["img"] - assert_true(np.all(X == Y), "Inserted and retrieved image are not identical") + assert np.all(X == Y), "Inserted and retrieved image are not identical" def test_drop(self): """Tests dropping tables""" @@ -305,7 +301,7 @@ def test_table_regexp(self): def test_table_size(self): """test getting the size of the table and its indices in bytes""" number_of_bytes = self.experiment.size_on_disk - assert_true(isinstance(number_of_bytes, int) and number_of_bytes > 100) + assert isinstance(number_of_bytes, int) and number_of_bytes > 100 def test_repr_html(self): - assert_true(self.ephys._repr_html_().strip().startswith(" Date: Wed, 13 Dec 2023 14:16:28 -0700 Subject: [PATCH 3/9] WIP migrate test_relation --- tests/test_relation.py | 334 ++++++++++++++++++++++------------------- 1 file changed, 178 insertions(+), 156 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index 0a6e2f436..e5e4a0ba0 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -1,14 +1,8 @@ +import pytest from inspect import getmembers import re import pandas import numpy as np -from nose.tools import ( - assert_equal, - assert_not_equal, - assert_true, - assert_list_equal, - raises, -) import datajoint as dj from datajoint.table import Table from unittest.mock import patch @@ -16,11 +10,55 @@ from . import schema -def relation_selector(attr): - try: - return issubclass(attr, Table) - except TypeError: - return False +@pytest.fixture +def test(schema_any): + assert len(schema.TTest.contents) + yield schema.TTest() + assert len(schema.TTest.contents) + + +@pytest.fixture +def test_extra(schema_any): + assert len(schema.TTest.contents) + yield schema.TTestExtra() + assert len(schema.TTest.contents) + + +@pytest.fixture +def test_no_extra(schema_any): + assert len(schema.TTest.contents) + yield schema.TTestNoExtra() + assert len(schema.TTest.contents) + + +@pytest.fixture +def user(schema_any): + return schema.User() + + +@pytest.fixture +def subject(schema_any): + return schema.Subject() + + +@pytest.fixture +def experiment(schema_any): + return schema.Experiment() + + +@pytest.fixture +def ephys(schema_any): + return schema.Ephys() + + +@pytest.fixture +def img(schema_any): + return schema.Image() + + +@pytest.fixture +def trash(schema_any): + return schema.UberTrash() class TestRelation: @@ -28,58 +66,38 @@ class TestRelation: Test base relations: insert, delete """ - @classmethod - def setup_class(cls): - cls.test = schema.TTest() - cls.test_extra = schema.TTestExtra() - cls.test_no_extra = schema.TTestNoExtra() - cls.user = schema.User() - cls.subject = schema.Subject() - cls.experiment = schema.Experiment() - cls.trial = schema.Trial() - cls.ephys = schema.Ephys() - cls.channel = schema.Ephys.Channel() - cls.img = schema.Image() - cls.trash = schema.UberTrash() - - def test_contents(self): + def test_contents(self, user, subject): """ test the ability of tables to self-populate using the contents property """ # test contents - assert self.user - assert len(self.user) == len(self.user.contents) - u = self.user.fetch(order_by=["username"]) - assert ( - list(u["username"]) == sorted([s[0] for s in self.user.contents])) + assert user + assert len(user) == len(user.contents) + u = user.fetch(order_by=["username"]) + assert list(u["username"]) == sorted([s[0] for s in user.contents]) # test prepare - assert self.subject - assert len(self.subject) == len(self.subject.contents) - u = self.subject.fetch(order_by=["subject_id"]) - assert ( - list(u["subject_id"]) == sorted([s[0] for s in self.subject.contents])) - - @raises(dj.DataJointError) - def test_misnamed_attribute1(self): - self.user.insert([dict(username="Bob"), dict(user="Alice")]) - - @raises(KeyError) - def test_misnamed_attribute2(self): - self.user.insert1(dict(user="Bob")) - - @raises(KeyError) - def test_extra_attribute1(self): - self.user.insert1(dict(username="Robert", spouse="Alice")) - - def test_extra_attribute2(self): - self.user.insert1( - dict(username="Robert", spouse="Alice"), ignore_extra_fields=True - ) + assert subject + assert len(subject) == len(subject.contents) + u = subject.fetch(order_by=["subject_id"]) + assert list(u["subject_id"]) == sorted([s[0] for s in subject.contents]) + + def test_misnamed_attribute1(self, user): + with pytest.raises(dj.DataJointError): + user.insert([dict(username="Bob"), dict(user="Alice")]) + + def test_misnamed_attribute2(self, user): + with pytest.raises(KeyError): + user.insert1(dict(user="Bob")) + + def test_extra_attribute1(self, user): + with pytest.raises(KeyError): + user.insert1(dict(username="Robert", spouse="Alice")) - @raises(NotImplementedError) - def test_missing_definition(self): - @schema.schema + def test_extra_attribute2(self, user): + user.insert1(dict(username="Robert", spouse="Alice"), ignore_extra_fields=True) + + def test_missing_definition(self, schema_any): class MissingDefinition(dj.Manual): definitions = """ # misspelled definition id : int @@ -87,29 +105,34 @@ class MissingDefinition(dj.Manual): comment : varchar(16) # otherwise everything's normal """ - @raises(dj.DataJointError) - def test_empty_insert1(self): - self.user.insert1(()) + with pytest.raises(NotImplementedError): + schema_any( + MissingDefinition, context=dict(MissingDefinition=MissingDefinition) + ) + + def test_empty_insert1(self, user): + with pytest.raises(dj.DataJointError): + user.insert1(()) - @raises(dj.DataJointError) - def test_empty_insert(self): - self.user.insert([()]) + def test_empty_insert(self, user): + with pytest.raises(dj.DataJointError): + user.insert([()]) - @raises(dj.DataJointError) - def test_wrong_arguments_insert(self): - self.user.insert1(("First", "Second")) + def test_wrong_arguments_insert(self, user): + with pytest.raises(dj.DataJointError): + user.insert1(("First", "Second")) - @raises(dj.DataJointError) - def test_wrong_insert_type(self): - self.user.insert1(3) + def test_wrong_insert_type(self, user): + with pytest.raises(dj.DataJointError): + user.insert1(3) - def test_insert_select(self): + def test_insert_select(self, subject): schema.TTest2.delete() schema.TTest2.insert(schema.TTest) assert len(schema.TTest2()) == len(schema.TTest()) - original_length = len(self.subject) - elements = self.subject.proj(..., s="subject_id") + original_length = len(subject) + elements = subject.proj(..., s="subject_id") elements = elements.proj( "real_id", "date_of_birth", @@ -117,10 +140,10 @@ def test_insert_select(self): subject_id="s+1000", species='"human"', ) - self.subject.insert(elements, ignore_extra_fields=True) - assert len(self.subject) == 2 * original_length + subject.insert(elements, ignore_extra_fields=True) + assert len(subject) == 2 * original_length - def test_insert_pandas_roundtrip(self): + def test_insert_pandas_roundtrip(self, schema_any): """ensure fetched frames can be inserted""" schema.TTest2.delete() n = len(schema.TTest()) @@ -131,7 +154,7 @@ def test_insert_pandas_roundtrip(self): schema.TTest2.insert(df) assert len(schema.TTest2()) == n - def test_insert_pandas_userframe(self): + def test_insert_pandas_userframe(self, schema_any): """ ensure simple user-created frames (1 field, non-custom index) can be inserted without extra index adjustment @@ -145,106 +168,102 @@ def test_insert_pandas_userframe(self): schema.TTest2.insert(df) assert len(schema.TTest2()) == n - @raises(dj.DataJointError) - def test_insert_select_ignore_extra_fields0(self): + def test_insert_select_ignore_extra_fields0(self, test, test_extra): """need ignore extra fields for insert select""" - self.test_extra.insert1((self.test.fetch("key").max() + 1, 0, 0)) - self.test.insert(self.test_extra) + test_extra.insert1((test.fetch("key").max() + 1, 0, 0)) + with pytest.raises(dj.DataJointError): + test.insert(test_extra) - def test_insert_select_ignore_extra_fields1(self): + def test_insert_select_ignore_extra_fields1(self, test, test_extra): """make sure extra fields works in insert select""" - self.test_extra.delete() - keyno = self.test.fetch("key").max() + 1 - self.test_extra.insert1((keyno, 0, 0)) - self.test.insert(self.test_extra, ignore_extra_fields=True) - assert keyno in self.test.fetch("key") + test_extra.delete() + keyno = test.fetch("key").max() + 1 + test_extra.insert1((keyno, 0, 0)) + test.insert(test_extra, ignore_extra_fields=True) + assert keyno in test.fetch("key") - def test_insert_select_ignore_extra_fields2(self): + def test_insert_select_ignore_extra_fields2(self, test_no_extra, test): """make sure insert select still works when ignoring extra fields when there are none""" - self.test_no_extra.delete() - self.test_no_extra.insert(self.test, ignore_extra_fields=True) + test_no_extra.delete() + test_no_extra.insert(test, ignore_extra_fields=True) - def test_insert_select_ignore_extra_fields3(self): + def test_insert_select_ignore_extra_fields3(self, test, test_no_extra, test_extra): """make sure insert select works for from query result""" - self.test_no_extra.delete() - keystr = str(self.test_extra.fetch("key").max()) - self.test_no_extra.insert( - (self.test_extra & "`key`=" + keystr), ignore_extra_fields=True - ) - - def test_skip_duplicates(self): + # Recreate table state from previous tests + keyno = test.fetch("key").max() + 1 + test_extra.insert1((keyno, 0, 0)) + test.insert(test_extra, ignore_extra_fields=True) + + assert len(test_extra.fetch("key")), "test_extra is empty" + test_no_extra.delete() + assert len(test_extra.fetch("key")), "test_extra is empty" + keystr = str(test_extra.fetch("key").max()) + test_no_extra.insert((test_extra & "`key`=" + keystr), ignore_extra_fields=True) + + def test_skip_duplicates(self, test_no_extra, test): """test that skip_duplicates works when inserting from another table""" - self.test_no_extra.delete() - self.test_no_extra.insert( - self.test, ignore_extra_fields=True, skip_duplicates=True - ) - self.test_no_extra.insert( - self.test, ignore_extra_fields=True, skip_duplicates=True - ) + test_no_extra.delete() + test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) + test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) - def test_replace(self): + def test_replace(self, subject): """ Test replacing or ignoring duplicate entries """ key = dict(subject_id=7) date = "2015-01-01" - self.subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) - assert ( - date == str((self.subject & key).fetch1("date_of_birth"))), "incorrect insert" + subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) + assert date == str((subject & key).fetch1("date_of_birth")), "incorrect insert" date = "2015-01-02" - self.subject.insert1( + subject.insert1( dict(key, real_id=7, date_of_birth=date, subject_notes=""), skip_duplicates=True, ) - assert_not_equal( - date, - str((self.subject & key).fetch1("date_of_birth")), - "inappropriate replace", - ) - self.subject.insert1( + assert date != str( + (subject & key).fetch1("date_of_birth") + ), "inappropriate replace" + subject.insert1( dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True ) - assert ( - date == str((self.subject & key).fetch1("date_of_birth"))), "replace failed" + assert date == str((subject & key).fetch1("date_of_birth")), "replace failed" - def test_delete_quick(self): + def test_delete_quick(self, subject): """Tests quick deletion""" tmp = np.array( [ (2, "Klara", "monkey", "2010-01-01", ""), (1, "Peter", "mouse", "2015-01-01", ""), ], - dtype=self.subject.heading.as_dtype, + dtype=subject.heading.as_dtype, ) - self.subject.insert(tmp) - s = self.subject & ( + subject.insert(tmp) + s = subject & ( "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) ) assert len(s) == 2, "insert did not work." s.delete_quick() assert len(s) == 0, "delete did not work." - def test_skip_duplicate(self): + def test_skip_duplicate(self, subject): """Tests if duplicates are properly skipped.""" tmp = np.array( [ (2, "Klara", "monkey", "2010-01-01", ""), (1, "Peter", "mouse", "2015-01-01", ""), ], - dtype=self.subject.heading.as_dtype, + dtype=subject.heading.as_dtype, ) - self.subject.insert(tmp) + subject.insert(tmp) tmp = np.array( [ (2, "Klara", "monkey", "2010-01-01", ""), (1, "Peter", "mouse", "2015-01-01", ""), ], - dtype=self.subject.heading.as_dtype, + dtype=subject.heading.as_dtype, ) - self.subject.insert(tmp, skip_duplicates=True) + subject.insert(tmp, skip_duplicates=True) - @raises(dj.errors.DuplicateError) - def test_not_skip_duplicate(self): + def test_not_skip_duplicate(self, subject): """Tests if duplicates are not skipped.""" tmp = np.array( [ @@ -252,56 +271,59 @@ def test_not_skip_duplicate(self): (2, "Klara", "monkey", "2010-01-01", ""), (1, "Peter", "mouse", "2015-01-01", ""), ], - dtype=self.subject.heading.as_dtype, + dtype=subject.heading.as_dtype, ) - self.subject.insert(tmp, skip_duplicates=False) + with pytest.raises(dj.errors.DuplicateError): + subject.insert(tmp, skip_duplicates=False) - @raises(dj.errors.MissingAttributeError) - def test_no_error_suppression(self): + def test_no_error_suppression(self, test): """skip_duplicates=True should not suppress other errors""" - self.test.insert([dict(key=100)], skip_duplicates=True) + with pytest.raises(dj.errors.MissingAttributeError): + test.insert([dict(key=100)], skip_duplicates=True) - def test_blob_insert(self): + def test_blob_insert(self, img): """Tests inserting and retrieving blobs.""" X = np.random.randn(20, 10) - self.img.insert1((1, X)) - Y = self.img.fetch()[0]["img"] + img.insert1((1, X)) + Y = img.fetch()[0]["img"] assert np.all(X == Y), "Inserted and retrieved image are not identical" - def test_drop(self): + def test_drop(self, trash): """Tests dropping tables""" dj.config["safemode"] = True with patch.object(dj.utils, "input", create=True, return_value="yes"): - self.trash.drop() + trash.drop() try: - self.trash.fetch() + trash.fetch() raise Exception("Fetched after table dropped.") except dj.DataJointError: pass finally: dj.config["safemode"] = False - def test_table_regexp(self): + def test_table_regexp(self, schema_any): """Test whether table names are matched by regular expressions""" + + def relation_selector(attr): + try: + return issubclass(attr, Table) + except TypeError: + return False + tiers = [dj.Imported, dj.Manual, dj.Lookup, dj.Computed] for name, rel in getmembers(schema, relation_selector): - assert_true( - re.match(rel.tier_regexp, rel.table_name), - "Regular expression does not match for {name}".format(name=name), - ) + assert re.match( + rel.tier_regexp, rel.table_name + ) == "Regular expression does not match for {name}".format(name=name) for tier in tiers: - assert_true( - issubclass(rel, tier) - or not re.match(tier.tier_regexp, rel.table_name), - "Regular expression matches for {name} but should not".format( - name=name - ), - ) - - def test_table_size(self): + assert issubclass(rel, tier) or not re.match( + tier.tier_regexp, rel.table_name + ), "Regular expression matches for {name} but should not".format(name=name) + + def test_table_size(self, experiment): """test getting the size of the table and its indices in bytes""" - number_of_bytes = self.experiment.size_on_disk + number_of_bytes = experiment.size_on_disk assert isinstance(number_of_bytes, int) and number_of_bytes > 100 - def test_repr_html(self): - assert self.ephys._repr_html_().strip().startswith(" Date: Wed, 13 Dec 2023 15:04:59 -0700 Subject: [PATCH 4/9] Move tests to top level --- tests/test_relation.py | 517 ++++++++++++++++++++--------------------- 1 file changed, 256 insertions(+), 261 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index e5e4a0ba0..f03328886 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -61,269 +61,264 @@ def trash(schema_any): return schema.UberTrash() -class TestRelation: +def test_contents(user, subject): """ - Test base relations: insert, delete + test the ability of tables to self-populate using the contents property """ - - def test_contents(self, user, subject): - """ - test the ability of tables to self-populate using the contents property - """ - # test contents - assert user - assert len(user) == len(user.contents) - u = user.fetch(order_by=["username"]) - assert list(u["username"]) == sorted([s[0] for s in user.contents]) - - # test prepare - assert subject - assert len(subject) == len(subject.contents) - u = subject.fetch(order_by=["subject_id"]) - assert list(u["subject_id"]) == sorted([s[0] for s in subject.contents]) - - def test_misnamed_attribute1(self, user): - with pytest.raises(dj.DataJointError): - user.insert([dict(username="Bob"), dict(user="Alice")]) - - def test_misnamed_attribute2(self, user): - with pytest.raises(KeyError): - user.insert1(dict(user="Bob")) - - def test_extra_attribute1(self, user): - with pytest.raises(KeyError): - user.insert1(dict(username="Robert", spouse="Alice")) - - def test_extra_attribute2(self, user): - user.insert1(dict(username="Robert", spouse="Alice"), ignore_extra_fields=True) - - def test_missing_definition(self, schema_any): - class MissingDefinition(dj.Manual): - definitions = """ # misspelled definition - id : int - --- - comment : varchar(16) # otherwise everything's normal - """ - - with pytest.raises(NotImplementedError): - schema_any( - MissingDefinition, context=dict(MissingDefinition=MissingDefinition) - ) - - def test_empty_insert1(self, user): - with pytest.raises(dj.DataJointError): - user.insert1(()) - - def test_empty_insert(self, user): - with pytest.raises(dj.DataJointError): - user.insert([()]) - - def test_wrong_arguments_insert(self, user): - with pytest.raises(dj.DataJointError): - user.insert1(("First", "Second")) - - def test_wrong_insert_type(self, user): - with pytest.raises(dj.DataJointError): - user.insert1(3) - - def test_insert_select(self, subject): - schema.TTest2.delete() - schema.TTest2.insert(schema.TTest) - assert len(schema.TTest2()) == len(schema.TTest()) - - original_length = len(subject) - elements = subject.proj(..., s="subject_id") - elements = elements.proj( - "real_id", - "date_of_birth", - "subject_notes", - subject_id="s+1000", - species='"human"', - ) - subject.insert(elements, ignore_extra_fields=True) - assert len(subject) == 2 * original_length - - def test_insert_pandas_roundtrip(self, schema_any): - """ensure fetched frames can be inserted""" - schema.TTest2.delete() - n = len(schema.TTest()) - assert n > 0 - df = schema.TTest.fetch(format="frame") - assert isinstance(df, pandas.DataFrame) - assert len(df) == n - schema.TTest2.insert(df) - assert len(schema.TTest2()) == n - - def test_insert_pandas_userframe(self, schema_any): - """ - ensure simple user-created frames (1 field, non-custom index) - can be inserted without extra index adjustment + # test contents + assert user + assert len(user) == len(user.contents) + u = user.fetch(order_by=["username"]) + assert list(u["username"]) == sorted([s[0] for s in user.contents]) + + # test prepare + assert subject + assert len(subject) == len(subject.contents) + u = subject.fetch(order_by=["subject_id"]) + assert list(u["subject_id"]) == sorted([s[0] for s in subject.contents]) + +def test_misnamed_attribute1(user): + with pytest.raises(dj.DataJointError): + user.insert([dict(username="Bob"), dict(user="Alice")]) + +def test_misnamed_attribute2(user): + with pytest.raises(KeyError): + user.insert1(dict(user="Bob")) + +def test_extra_attribute1(user): + with pytest.raises(KeyError): + user.insert1(dict(username="Robert", spouse="Alice")) + +def test_extra_attribute2(user): + user.insert1(dict(username="Robert", spouse="Alice"), ignore_extra_fields=True) + +def test_missing_definition(schema_any): + class MissingDefinition(dj.Manual): + definitions = """ # misspelled definition + id : int + --- + comment : varchar(16) # otherwise everything's normal """ - schema.TTest2.delete() - n = len(schema.TTest()) - assert n > 0 - df = pandas.DataFrame(schema.TTest.fetch()) - assert isinstance(df, pandas.DataFrame) - assert len(df) == n - schema.TTest2.insert(df) - assert len(schema.TTest2()) == n - - def test_insert_select_ignore_extra_fields0(self, test, test_extra): - """need ignore extra fields for insert select""" - test_extra.insert1((test.fetch("key").max() + 1, 0, 0)) - with pytest.raises(dj.DataJointError): - test.insert(test_extra) - - def test_insert_select_ignore_extra_fields1(self, test, test_extra): - """make sure extra fields works in insert select""" - test_extra.delete() - keyno = test.fetch("key").max() + 1 - test_extra.insert1((keyno, 0, 0)) - test.insert(test_extra, ignore_extra_fields=True) - assert keyno in test.fetch("key") - - def test_insert_select_ignore_extra_fields2(self, test_no_extra, test): - """make sure insert select still works when ignoring extra fields when there are none""" - test_no_extra.delete() - test_no_extra.insert(test, ignore_extra_fields=True) - - def test_insert_select_ignore_extra_fields3(self, test, test_no_extra, test_extra): - """make sure insert select works for from query result""" - # Recreate table state from previous tests - keyno = test.fetch("key").max() + 1 - test_extra.insert1((keyno, 0, 0)) - test.insert(test_extra, ignore_extra_fields=True) - - assert len(test_extra.fetch("key")), "test_extra is empty" - test_no_extra.delete() - assert len(test_extra.fetch("key")), "test_extra is empty" - keystr = str(test_extra.fetch("key").max()) - test_no_extra.insert((test_extra & "`key`=" + keystr), ignore_extra_fields=True) - - def test_skip_duplicates(self, test_no_extra, test): - """test that skip_duplicates works when inserting from another table""" - test_no_extra.delete() - test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) - test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) - - def test_replace(self, subject): - """ - Test replacing or ignoring duplicate entries - """ - key = dict(subject_id=7) - date = "2015-01-01" - subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) - assert date == str((subject & key).fetch1("date_of_birth")), "incorrect insert" - date = "2015-01-02" - subject.insert1( - dict(key, real_id=7, date_of_birth=date, subject_notes=""), - skip_duplicates=True, - ) - assert date != str( - (subject & key).fetch1("date_of_birth") - ), "inappropriate replace" - subject.insert1( - dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True - ) - assert date == str((subject & key).fetch1("date_of_birth")), "replace failed" - - def test_delete_quick(self, subject): - """Tests quick deletion""" - tmp = np.array( - [ - (2, "Klara", "monkey", "2010-01-01", ""), - (1, "Peter", "mouse", "2015-01-01", ""), - ], - dtype=subject.heading.as_dtype, - ) - subject.insert(tmp) - s = subject & ( - "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) - ) - assert len(s) == 2, "insert did not work." - s.delete_quick() - assert len(s) == 0, "delete did not work." - - def test_skip_duplicate(self, subject): - """Tests if duplicates are properly skipped.""" - tmp = np.array( - [ - (2, "Klara", "monkey", "2010-01-01", ""), - (1, "Peter", "mouse", "2015-01-01", ""), - ], - dtype=subject.heading.as_dtype, - ) - subject.insert(tmp) - tmp = np.array( - [ - (2, "Klara", "monkey", "2010-01-01", ""), - (1, "Peter", "mouse", "2015-01-01", ""), - ], - dtype=subject.heading.as_dtype, - ) - subject.insert(tmp, skip_duplicates=True) - - def test_not_skip_duplicate(self, subject): - """Tests if duplicates are not skipped.""" - tmp = np.array( - [ - (2, "Klara", "monkey", "2010-01-01", ""), - (2, "Klara", "monkey", "2010-01-01", ""), - (1, "Peter", "mouse", "2015-01-01", ""), - ], - dtype=subject.heading.as_dtype, + + with pytest.raises(NotImplementedError): + schema_any( + MissingDefinition, context=dict(MissingDefinition=MissingDefinition) ) - with pytest.raises(dj.errors.DuplicateError): - subject.insert(tmp, skip_duplicates=False) - - def test_no_error_suppression(self, test): - """skip_duplicates=True should not suppress other errors""" - with pytest.raises(dj.errors.MissingAttributeError): - test.insert([dict(key=100)], skip_duplicates=True) - - def test_blob_insert(self, img): - """Tests inserting and retrieving blobs.""" - X = np.random.randn(20, 10) - img.insert1((1, X)) - Y = img.fetch()[0]["img"] - assert np.all(X == Y), "Inserted and retrieved image are not identical" - - def test_drop(self, trash): - """Tests dropping tables""" - dj.config["safemode"] = True - with patch.object(dj.utils, "input", create=True, return_value="yes"): - trash.drop() + +def test_empty_insert1(user): + with pytest.raises(dj.DataJointError): + user.insert1(()) + +def test_empty_insert(user): + with pytest.raises(dj.DataJointError): + user.insert([()]) + +def test_wrong_arguments_insert(user): + with pytest.raises(dj.DataJointError): + user.insert1(("First", "Second")) + +def test_wrong_insert_type(user): + with pytest.raises(dj.DataJointError): + user.insert1(3) + +def test_insert_select(subject): + schema.TTest2.delete() + schema.TTest2.insert(schema.TTest) + assert len(schema.TTest2()) == len(schema.TTest()) + + original_length = len(subject) + elements = subject.proj(..., s="subject_id") + elements = elements.proj( + "real_id", + "date_of_birth", + "subject_notes", + subject_id="s+1000", + species='"human"', + ) + subject.insert(elements, ignore_extra_fields=True) + assert len(subject) == 2 * original_length + +def test_insert_pandas_roundtrip(schema_any): + """ensure fetched frames can be inserted""" + schema.TTest2.delete() + n = len(schema.TTest()) + assert n > 0 + df = schema.TTest.fetch(format="frame") + assert isinstance(df, pandas.DataFrame) + assert len(df) == n + schema.TTest2.insert(df) + assert len(schema.TTest2()) == n + +def test_insert_pandas_userframe(schema_any): + """ + ensure simple user-created frames (1 field, non-custom index) + can be inserted without extra index adjustment + """ + schema.TTest2.delete() + n = len(schema.TTest()) + assert n > 0 + df = pandas.DataFrame(schema.TTest.fetch()) + assert isinstance(df, pandas.DataFrame) + assert len(df) == n + schema.TTest2.insert(df) + assert len(schema.TTest2()) == n + +def test_insert_select_ignore_extra_fields0(test, test_extra): + """need ignore extra fields for insert select""" + test_extra.insert1((test.fetch("key").max() + 1, 0, 0)) + with pytest.raises(dj.DataJointError): + test.insert(test_extra) + +def test_insert_select_ignore_extra_fields1(test, test_extra): + """make sure extra fields works in insert select""" + test_extra.delete() + keyno = test.fetch("key").max() + 1 + test_extra.insert1((keyno, 0, 0)) + test.insert(test_extra, ignore_extra_fields=True) + assert keyno in test.fetch("key") + +def test_insert_select_ignore_extra_fields2(test_no_extra, test): + """make sure insert select still works when ignoring extra fields when there are none""" + test_no_extra.delete() + test_no_extra.insert(test, ignore_extra_fields=True) + +def test_insert_select_ignore_extra_fields3(test, test_no_extra, test_extra): + """make sure insert select works for from query result""" + # Recreate table state from previous tests + keyno = test.fetch("key").max() + 1 + test_extra.insert1((keyno, 0, 0)) + test.insert(test_extra, ignore_extra_fields=True) + + assert len(test_extra.fetch("key")), "test_extra is empty" + test_no_extra.delete() + assert len(test_extra.fetch("key")), "test_extra is empty" + keystr = str(test_extra.fetch("key").max()) + test_no_extra.insert((test_extra & "`key`=" + keystr), ignore_extra_fields=True) + +def test_skip_duplicates(test_no_extra, test): + """test that skip_duplicates works when inserting from another table""" + test_no_extra.delete() + test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) + test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) + +def test_replace(subject): + """ + Test replacing or ignoring duplicate entries + """ + key = dict(subject_id=7) + date = "2015-01-01" + subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes="")) + assert date == str((subject & key).fetch1("date_of_birth")), "incorrect insert" + date = "2015-01-02" + subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), + skip_duplicates=True, + ) + assert date != str( + (subject & key).fetch1("date_of_birth") + ), "inappropriate replace" + subject.insert1( + dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True + ) + assert date == str((subject & key).fetch1("date_of_birth")), "replace failed" + +def test_delete_quick(subject): + """Tests quick deletion""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=subject.heading.as_dtype, + ) + subject.insert(tmp) + s = subject & ( + "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) + ) + assert len(s) == 2, "insert did not work." + s.delete_quick() + assert len(s) == 0, "delete did not work." + +def test_skip_duplicate(subject): + """Tests if duplicates are properly skipped.""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=subject.heading.as_dtype, + ) + subject.insert(tmp) + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=subject.heading.as_dtype, + ) + subject.insert(tmp, skip_duplicates=True) + +def test_not_skip_duplicate(subject): + """Tests if duplicates are not skipped.""" + tmp = np.array( + [ + (2, "Klara", "monkey", "2010-01-01", ""), + (2, "Klara", "monkey", "2010-01-01", ""), + (1, "Peter", "mouse", "2015-01-01", ""), + ], + dtype=subject.heading.as_dtype, + ) + with pytest.raises(dj.errors.DuplicateError): + subject.insert(tmp, skip_duplicates=False) + +def test_no_error_suppression(test): + """skip_duplicates=True should not suppress other errors""" + with pytest.raises(dj.errors.MissingAttributeError): + test.insert([dict(key=100)], skip_duplicates=True) + +def test_blob_insert(img): + """Tests inserting and retrieving blobs.""" + X = np.random.randn(20, 10) + img.insert1((1, X)) + Y = img.fetch()[0]["img"] + assert np.all(X == Y), "Inserted and retrieved image are not identical" + +def test_drop(trash): + """Tests dropping tables""" + dj.config["safemode"] = True + with patch.object(dj.utils, "input", create=True, return_value="yes"): + trash.drop() + try: + trash.fetch() + raise Exception("Fetched after table dropped.") + except dj.DataJointError: + pass + finally: + dj.config["safemode"] = False + +def test_table_regexp(schema_any): + """Test whether table names are matched by regular expressions""" + + def relation_selector(attr): try: - trash.fetch() - raise Exception("Fetched after table dropped.") - except dj.DataJointError: - pass - finally: - dj.config["safemode"] = False - - def test_table_regexp(self, schema_any): - """Test whether table names are matched by regular expressions""" - - def relation_selector(attr): - try: - return issubclass(attr, Table) - except TypeError: - return False - - tiers = [dj.Imported, dj.Manual, dj.Lookup, dj.Computed] - for name, rel in getmembers(schema, relation_selector): - assert re.match( - rel.tier_regexp, rel.table_name - ) == "Regular expression does not match for {name}".format(name=name) - for tier in tiers: - assert issubclass(rel, tier) or not re.match( - tier.tier_regexp, rel.table_name - ), "Regular expression matches for {name} but should not".format(name=name) - - def test_table_size(self, experiment): - """test getting the size of the table and its indices in bytes""" - number_of_bytes = experiment.size_on_disk - assert isinstance(number_of_bytes, int) and number_of_bytes > 100 - - def test_repr_html(self, ephys): - assert ephys._repr_html_().strip().startswith(" 100 + +def test_repr_html(ephys): + assert ephys._repr_html_().strip().startswith(" Date: Wed, 13 Dec 2023 20:16:02 -0700 Subject: [PATCH 5/9] Fix typo in test --- tests/test_relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index f03328886..6ef5de3c4 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -309,7 +309,7 @@ def relation_selector(attr): for name, rel in getmembers(schema, relation_selector): assert re.match( rel.tier_regexp, rel.table_name - ) == "Regular expression does not match for {name}".format(name=name) + ), "Regular expression does not match for {name}".format(name=name) for tier in tiers: assert issubclass(rel, tier) or not re.match( tier.tier_regexp, rel.table_name From b7ce66879ee198bbe498d4b3ce9114bb3db7ba9e Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Wed, 13 Dec 2023 20:16:22 -0700 Subject: [PATCH 6/9] Format with black --- tests/test_relation.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index 6ef5de3c4..05f6fe7c8 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -77,21 +77,26 @@ def test_contents(user, subject): u = subject.fetch(order_by=["subject_id"]) assert list(u["subject_id"]) == sorted([s[0] for s in subject.contents]) + def test_misnamed_attribute1(user): with pytest.raises(dj.DataJointError): user.insert([dict(username="Bob"), dict(user="Alice")]) + def test_misnamed_attribute2(user): with pytest.raises(KeyError): user.insert1(dict(user="Bob")) + def test_extra_attribute1(user): with pytest.raises(KeyError): user.insert1(dict(username="Robert", spouse="Alice")) + def test_extra_attribute2(user): user.insert1(dict(username="Robert", spouse="Alice"), ignore_extra_fields=True) + def test_missing_definition(schema_any): class MissingDefinition(dj.Manual): definitions = """ # misspelled definition @@ -101,26 +106,29 @@ class MissingDefinition(dj.Manual): """ with pytest.raises(NotImplementedError): - schema_any( - MissingDefinition, context=dict(MissingDefinition=MissingDefinition) - ) + schema_any(MissingDefinition, context=dict(MissingDefinition=MissingDefinition)) + def test_empty_insert1(user): with pytest.raises(dj.DataJointError): user.insert1(()) + def test_empty_insert(user): with pytest.raises(dj.DataJointError): user.insert([()]) + def test_wrong_arguments_insert(user): with pytest.raises(dj.DataJointError): user.insert1(("First", "Second")) + def test_wrong_insert_type(user): with pytest.raises(dj.DataJointError): user.insert1(3) + def test_insert_select(subject): schema.TTest2.delete() schema.TTest2.insert(schema.TTest) @@ -138,6 +146,7 @@ def test_insert_select(subject): subject.insert(elements, ignore_extra_fields=True) assert len(subject) == 2 * original_length + def test_insert_pandas_roundtrip(schema_any): """ensure fetched frames can be inserted""" schema.TTest2.delete() @@ -149,6 +158,7 @@ def test_insert_pandas_roundtrip(schema_any): schema.TTest2.insert(df) assert len(schema.TTest2()) == n + def test_insert_pandas_userframe(schema_any): """ ensure simple user-created frames (1 field, non-custom index) @@ -163,12 +173,14 @@ def test_insert_pandas_userframe(schema_any): schema.TTest2.insert(df) assert len(schema.TTest2()) == n + def test_insert_select_ignore_extra_fields0(test, test_extra): """need ignore extra fields for insert select""" test_extra.insert1((test.fetch("key").max() + 1, 0, 0)) with pytest.raises(dj.DataJointError): test.insert(test_extra) + def test_insert_select_ignore_extra_fields1(test, test_extra): """make sure extra fields works in insert select""" test_extra.delete() @@ -177,11 +189,13 @@ def test_insert_select_ignore_extra_fields1(test, test_extra): test.insert(test_extra, ignore_extra_fields=True) assert keyno in test.fetch("key") + def test_insert_select_ignore_extra_fields2(test_no_extra, test): """make sure insert select still works when ignoring extra fields when there are none""" test_no_extra.delete() test_no_extra.insert(test, ignore_extra_fields=True) + def test_insert_select_ignore_extra_fields3(test, test_no_extra, test_extra): """make sure insert select works for from query result""" # Recreate table state from previous tests @@ -195,12 +209,14 @@ def test_insert_select_ignore_extra_fields3(test, test_no_extra, test_extra): keystr = str(test_extra.fetch("key").max()) test_no_extra.insert((test_extra & "`key`=" + keystr), ignore_extra_fields=True) + def test_skip_duplicates(test_no_extra, test): """test that skip_duplicates works when inserting from another table""" test_no_extra.delete() test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True) + def test_replace(subject): """ Test replacing or ignoring duplicate entries @@ -214,14 +230,13 @@ def test_replace(subject): dict(key, real_id=7, date_of_birth=date, subject_notes=""), skip_duplicates=True, ) - assert date != str( - (subject & key).fetch1("date_of_birth") - ), "inappropriate replace" + assert date != str((subject & key).fetch1("date_of_birth")), "inappropriate replace" subject.insert1( dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True ) assert date == str((subject & key).fetch1("date_of_birth")), "replace failed" + def test_delete_quick(subject): """Tests quick deletion""" tmp = np.array( @@ -232,13 +247,12 @@ def test_delete_quick(subject): dtype=subject.heading.as_dtype, ) subject.insert(tmp) - s = subject & ( - "subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]) - ) + s = subject & ("subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"])) assert len(s) == 2, "insert did not work." s.delete_quick() assert len(s) == 0, "delete did not work." + def test_skip_duplicate(subject): """Tests if duplicates are properly skipped.""" tmp = np.array( @@ -258,6 +272,7 @@ def test_skip_duplicate(subject): ) subject.insert(tmp, skip_duplicates=True) + def test_not_skip_duplicate(subject): """Tests if duplicates are not skipped.""" tmp = np.array( @@ -271,11 +286,13 @@ def test_not_skip_duplicate(subject): with pytest.raises(dj.errors.DuplicateError): subject.insert(tmp, skip_duplicates=False) + def test_no_error_suppression(test): """skip_duplicates=True should not suppress other errors""" with pytest.raises(dj.errors.MissingAttributeError): test.insert([dict(key=100)], skip_duplicates=True) + def test_blob_insert(img): """Tests inserting and retrieving blobs.""" X = np.random.randn(20, 10) @@ -283,6 +300,7 @@ def test_blob_insert(img): Y = img.fetch()[0]["img"] assert np.all(X == Y), "Inserted and retrieved image are not identical" + def test_drop(trash): """Tests dropping tables""" dj.config["safemode"] = True @@ -296,6 +314,7 @@ def test_drop(trash): finally: dj.config["safemode"] = False + def test_table_regexp(schema_any): """Test whether table names are matched by regular expressions""" @@ -315,10 +334,12 @@ def relation_selector(attr): tier.tier_regexp, rel.table_name ), "Regular expression matches for {name} but should not".format(name=name) + def test_table_size(experiment): """test getting the size of the table and its indices in bytes""" number_of_bytes = experiment.size_on_disk assert isinstance(number_of_bytes, int) and number_of_bytes > 100 + def test_repr_html(ephys): assert ephys._repr_html_().strip().startswith(" Date: Wed, 13 Dec 2023 20:19:31 -0700 Subject: [PATCH 7/9] Clean up --- tests/test_relation.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index 05f6fe7c8..5f60b88eb 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -12,23 +12,17 @@ @pytest.fixture def test(schema_any): - assert len(schema.TTest.contents) yield schema.TTest() - assert len(schema.TTest.contents) @pytest.fixture def test_extra(schema_any): - assert len(schema.TTest.contents) yield schema.TTestExtra() - assert len(schema.TTest.contents) @pytest.fixture def test_no_extra(schema_any): - assert len(schema.TTest.contents) yield schema.TTestNoExtra() - assert len(schema.TTest.contents) @pytest.fixture From 3998cb48017b67640245a85bf2b257ad024a39fd Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Thu, 14 Dec 2023 08:37:05 -0700 Subject: [PATCH 8/9] Use fixture for TTest2 --- tests/test_relation.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index 5f60b88eb..4a13df448 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -15,6 +15,11 @@ def test(schema_any): yield schema.TTest() +@pytest.fixture +def test2(schema_any): + yield schema.TTest2() + + @pytest.fixture def test_extra(schema_any): yield schema.TTestExtra() @@ -123,10 +128,10 @@ def test_wrong_insert_type(user): user.insert1(3) -def test_insert_select(subject): - schema.TTest2.delete() - schema.TTest2.insert(schema.TTest) - assert len(schema.TTest2()) == len(schema.TTest()) +def test_insert_select(subject, test2): + test2.delete() + test2.insert(schema.TTest) + assert len(test2) == len(schema.TTest()) original_length = len(subject) elements = subject.proj(..., s="subject_id") @@ -141,31 +146,31 @@ def test_insert_select(subject): assert len(subject) == 2 * original_length -def test_insert_pandas_roundtrip(schema_any): +def test_insert_pandas_roundtrip(test2): """ensure fetched frames can be inserted""" - schema.TTest2.delete() + test2.delete() n = len(schema.TTest()) assert n > 0 df = schema.TTest.fetch(format="frame") assert isinstance(df, pandas.DataFrame) assert len(df) == n - schema.TTest2.insert(df) - assert len(schema.TTest2()) == n + test2.insert(df) + assert len(test2) == n -def test_insert_pandas_userframe(schema_any): +def test_insert_pandas_userframe(test2): """ ensure simple user-created frames (1 field, non-custom index) can be inserted without extra index adjustment """ - schema.TTest2.delete() + test2.delete() n = len(schema.TTest()) assert n > 0 df = pandas.DataFrame(schema.TTest.fetch()) assert isinstance(df, pandas.DataFrame) assert len(df) == n - schema.TTest2.insert(df) - assert len(schema.TTest2()) == n + test2.insert(df) + assert len(test2) == n def test_insert_select_ignore_extra_fields0(test, test_extra): From 3830c5589d789a04de6d5924b45d486a44323959 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Thu, 14 Dec 2023 08:48:18 -0700 Subject: [PATCH 9/9] Use fixture for TTest --- tests/test_relation.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_relation.py b/tests/test_relation.py index 4a13df448..2011a1901 100644 --- a/tests/test_relation.py +++ b/tests/test_relation.py @@ -128,10 +128,10 @@ def test_wrong_insert_type(user): user.insert1(3) -def test_insert_select(subject, test2): +def test_insert_select(subject, test, test2): test2.delete() - test2.insert(schema.TTest) - assert len(test2) == len(schema.TTest()) + test2.insert(test) + assert len(test2) == len(test) original_length = len(subject) elements = subject.proj(..., s="subject_id") @@ -146,27 +146,27 @@ def test_insert_select(subject, test2): assert len(subject) == 2 * original_length -def test_insert_pandas_roundtrip(test2): +def test_insert_pandas_roundtrip(test, test2): """ensure fetched frames can be inserted""" test2.delete() - n = len(schema.TTest()) + n = len(test) assert n > 0 - df = schema.TTest.fetch(format="frame") + df = test.fetch(format="frame") assert isinstance(df, pandas.DataFrame) assert len(df) == n test2.insert(df) assert len(test2) == n -def test_insert_pandas_userframe(test2): +def test_insert_pandas_userframe(test, test2): """ ensure simple user-created frames (1 field, non-custom index) can be inserted without extra index adjustment """ test2.delete() - n = len(schema.TTest()) + n = len(test) assert n > 0 - df = pandas.DataFrame(schema.TTest.fetch()) + df = pandas.DataFrame(test.fetch()) assert isinstance(df, pandas.DataFrame) assert len(df) == n test2.insert(df)