diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 2b11f8d..1f92b82 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -93,4 +93,4 @@ jobs: run: "hatch run test:check_types" - name: "Run tests" if: steps.filters.outputs.src == 'true' || steps.filters.outputs.workflows == 'true' || github.event.schedule != '' - run: hatch test + run: env TEST_NO_RISK_SEGFAULTS=true hatch test diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 92683bb..a8705e8 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -231,6 +231,7 @@ async def _aexit_raw(self) -> bool: self._database._connection = None return closing + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout async def _aexit(self) -> typing.Optional[Thread]: if self._full_isolation: assert self._connection_thread_lock is not None @@ -248,7 +249,6 @@ async def _aexit(self) -> typing.Optional[Thread]: await self._aexit_raw() return None - @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout async def __aexit__( self, exc_type: typing.Optional[typing.Type[BaseException]] = None, @@ -257,7 +257,7 @@ async def __aexit__( ) -> None: thread = await self._aexit() if thread is not None and thread is not current_thread(): - while thread.is_alive(): # noqa: ASYNC110 + while thread.is_alive(): await asyncio.sleep(self.poll_interval) thread.join(1) diff --git a/databasez/core/database.py b/databasez/core/database.py index 3b88b6d..3439886 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -211,7 +211,7 @@ class Database: def __init__( self, - url: typing.Optional[typing.Union[str, DatabaseURL, URL, Database]] = None, + url: typing.Union[str, DatabaseURL, URL, Database, None] = None, *, force_rollback: typing.Union[bool, None] = None, config: typing.Optional["DictAny"] = None, diff --git a/databasez/testclient.py b/databasez/testclient.py index cf40054..ebdeb0f 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -3,7 +3,7 @@ import typing from typing import Any -import sqlalchemy as sa +import sqlalchemy from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy_utils.functions.database import _sqlite_file_exists from sqlalchemy_utils.functions.orm import quote @@ -45,7 +45,7 @@ class DatabaseTestClient(Database): def __init__( self, - url: typing.Union[str, "DatabaseURL", "sa.URL", Database], + url: typing.Union[str, DatabaseURL, sqlalchemy.URL, Database, None] = None, *, force_rollback: typing.Union[bool, None] = None, full_isolation: typing.Union[bool, None] = None, @@ -66,13 +66,6 @@ def __init__( test_prefix = self.testclient_default_test_prefix self._setup_executed_init = False if isinstance(url, Database): - test_database_url = ( - url.url.replace(database=f"{test_prefix}{url.url.database}") - if test_prefix - else url.url - ) - # replace only if not cloning a DatabaseTestClient - self.test_db_url = str(getattr(url, "test_db_url", test_database_url)) self.use_existing = getattr(url, "use_existing", use_existing) self.drop = getattr(url, "drop", drop_database) # only if explicit set to False @@ -80,9 +73,12 @@ def __init__( self.setup_protected(self.testclient_operation_timeout_init) self._setup_executed_init = True super().__init__(url, force_rollback=force_rollback, **options) - # fix url - if str(self.url) != self.test_db_url: - self.url = test_database_url + if hasattr(url, "test_db_url"): + self.test_db_url = url.test_db_url + else: + if test_prefix: + self.url = self.url.replace(database=f"{test_prefix}{self.url.database}") + self.test_db_url = str(self.url) else: if lazy_setup is None: lazy_setup = self.testclient_default_lazy_setup @@ -90,25 +86,22 @@ def __init__( force_rollback = self.testclient_default_force_rollback if poll_interval is None: poll_interval = self.testclient_default_poll_interval - url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) - test_database_url = ( - url.replace(database=f"{test_prefix}{url.database}") if test_prefix else url - ) - self.test_db_url = str(test_database_url) self.use_existing = use_existing self.drop = drop_database - # if None or False - if not lazy_setup: - self.setup_protected(self.testclient_operation_timeout_init) - self._setup_executed_init = True - super().__init__( - test_database_url, + url, force_rollback=force_rollback, full_isolation=full_isolation, poll_interval=poll_interval, **options, ) + if test_prefix: + self.url = self.url.replace(database=f"{test_prefix}{self.url.database}") + self.test_db_url = str(self.url) + # if None or False + if not lazy_setup: + self.setup_protected(self.testclient_operation_timeout_init) + self._setup_executed_init = True async def setup(self) -> None: """ @@ -150,7 +143,7 @@ async def is_database_exist(self) -> Any: return await self.database_exists(self.test_db_url) @classmethod - async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> bool: + async def database_exists(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> bool: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database dialect_name = url.sqla_url.get_dialect(True).name @@ -160,7 +153,9 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> url = url.replace(database=db) async with Database(url, full_isolation=False, force_rollback=False) as db_client: try: - return bool(await _get_scalar_result(db_client.engine, sa.text(text))) + return bool( + await _get_scalar_result(db_client.engine, sqlalchemy.text(text)) + ) except (ProgrammingError, OperationalError): pass return False @@ -172,7 +167,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> "WHERE SCHEMA_NAME = '%s'" % database ) async with Database(url, full_isolation=False, force_rollback=False) as db_client: - return bool(await _get_scalar_result(db_client.engine, sa.text(text))) + return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text))) elif dialect_name == "sqlite": if database: @@ -185,14 +180,14 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> text = "SELECT 1" async with Database(url, full_isolation=False, force_rollback=False) as db_client: try: - return bool(await _get_scalar_result(db_client.engine, sa.text(text))) + return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text))) except (ProgrammingError, OperationalError): return False @classmethod async def create_database( cls, - url: typing.Union[str, "sa.URL", DatabaseURL], + url: typing.Union[str, "sqlalchemy.URL", DatabaseURL], encoding: str = "utf8", template: typing.Any = None, ) -> None: @@ -229,29 +224,29 @@ async def create_database( text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( quote(conn, database), encoding, quote(conn, template) ) - await conn.execute(sa.text(text)) + await conn.execute(sqlalchemy.text(text)) elif dialect_name == "mysql": async with db_client.engine.begin() as conn: # type: ignore text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( quote(conn, database), encoding ) - await conn.execute(sa.text(text)) + await conn.execute(sqlalchemy.text(text)) elif dialect_name == "sqlite" and database != ":memory:": if database: # create a sqlite file async with db_client.engine.begin() as conn: # type: ignore - await conn.execute(sa.text("CREATE TABLE DB(id int)")) - await conn.execute(sa.text("DROP TABLE DB")) + await conn.execute(sqlalchemy.text("CREATE TABLE DB(id int)")) + await conn.execute(sqlalchemy.text("DROP TABLE DB")) else: async with db_client.engine.begin() as conn: # type: ignore text = f"CREATE DATABASE {quote(conn, database)}" - await conn.execute(sa.text(text)) + await conn.execute(sqlalchemy.text(text)) @classmethod - async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> None: + async def drop_database(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> None: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database dialect = url.sqla_url.get_dialect(True) @@ -310,7 +305,7 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N else: async with db_client.connection() as conn: text = f"DROP DATABASE {quote(conn.async_connection, database)}" - await conn.execute(sa.text(text)) + await conn.execute(sqlalchemy.text(text)) def drop_db_protected(self) -> None: thread = ThreadPassingExceptions( diff --git a/docs/release-notes.md b/docs/release-notes.md index 6e147bf..052251e 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -7,6 +7,7 @@ - The global connection is now entered lazily despite sub-databases. - Fix deadlock with full_isolation off. - Fix database.transaction() failing because of AsyncDatabaseHelper. +- Fix DatabaseTestClient not able to use config initialization. ## 0.10.2 diff --git a/pyproject.toml b/pyproject.toml index a5ab411..713a089 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ select = [ ] ignore = [ + "ASYNC110", # use anio.Event "B008", # do not perform function calls in argument defaults "C901", # too complex "E712", # Comparison to True should be cond is True diff --git a/tests/shared_db.py b/tests/shared_db.py index 619b40c..a066444 100644 --- a/tests/shared_db.py +++ b/tests/shared_db.py @@ -79,7 +79,14 @@ async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTe url, test_prefix="", use_existing=not is_sqlite, drop_database=is_sqlite ) else: - database = Database(config=url) + scheme = url["connection"]["credentials"]["scheme"] + is_sqlite = scheme.startswith("sqlite") + database = DatabaseTestClient( + config=url, + test_prefix="", + use_existing=not is_sqlite, + drop_database=is_sqlite, + ) await database.connect() await database.create_all(meta) return database @@ -88,5 +95,6 @@ async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTe async def stop_database_client(database: Database, meta=None): if meta is None: meta = metadata - await database.drop_all(meta) + if not getattr(database, "drop", False): + await database.drop_all(meta) await database.disconnect() diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index b57035c..1089e85 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -21,7 +21,9 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): +if os.environ.get("TEST_NO_RISK_SEGFAULTS") or not any( + (x.endswith(" for SQL Server") for x in pyodbc.drivers()) +): DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) diff --git a/tests/test_really_old_jdbc.py b/tests/test_really_old_jdbc.py index c97e8d0..18381fa 100644 --- a/tests/test_really_old_jdbc.py +++ b/tests/test_really_old_jdbc.py @@ -1,6 +1,6 @@ import pytest import sqlalchemy -from sqlalchemy.pool import NullPool +from sqlalchemy.pool import StaticPool from databasez import Database @@ -25,7 +25,7 @@ async def test_jdbc_connect(): """ async with Database( "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", - poolclass=NullPool, + poolclass=StaticPool, ) as database: async with database.connection(): pass @@ -39,7 +39,7 @@ async def test_jdbc_queries(): """ async with Database( "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", - poolclass=NullPool, + poolclass=StaticPool, ) as database: async with database.connection() as connection: await connection.create_all(metadata)