diff --git a/databasez/__init__.py b/databasez/__init__.py index d69f2a5..c084037 100644 --- a/databasez/__init__.py +++ b/databasez/__init__.py @@ -1,5 +1,5 @@ from databasez.core import Database, DatabaseURL -__version__ = "0.9.7" +__version__ = "0.10.0" __all__ = ["Database", "DatabaseURL"] diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 94b3962..1ed5904 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -1,8 +1,11 @@ from __future__ import annotations import asyncio +import sys import typing import weakref +from contextvars import copy_context +from threading import Event, RLock, Thread, current_thread from types import TracebackType from sqlalchemy import text @@ -16,30 +19,72 @@ from sqlalchemy import MetaData from sqlalchemy.sql import ClauseElement + from databasez.types import BatchCallable, BatchCallableResult + from .database import Database +async def _startup(database: Database, is_initialized: Event) -> None: + await database.connect() + _global_connection = typing.cast(Connection, database._global_connection) + await _global_connection._aenter() + if sys.version_info < (3, 10): + # for old python versions <3.10 the locks must be created in the same event loop + _global_connection._query_lock = asyncio.Lock() + _global_connection._connection_lock = asyncio.Lock() + _global_connection._transaction_lock = asyncio.Lock() + is_initialized.set() + + +def _init_thread(database: Database, is_initialized: Event) -> None: + loop = asyncio.new_event_loop() + task = loop.create_task(_startup(database, is_initialized)) + try: + loop.run_forever() + except RuntimeError: + pass + try: + task.result() + finally: + try: + loop.run_until_complete(database.disconnect()) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + del task + loop.close() + database._loop = None + + class Connection: def __init__( - self, database: Database, backend: interfaces.DatabaseBackend, force_rollback: bool = False + self, database: Database, force_rollback: bool = False, full_isolation: bool = False ) -> None: - self._database = database - self._backend = backend - + self._orig_database = self._database = database + self._full_isolation = full_isolation + self._connection_thread_lock: typing.Optional[RLock] = None + self._isolation_thread: typing.Optional[Thread] = None + if self._full_isolation: + self._database = database.__class__( + database, force_rollback=force_rollback, full_isolation=False + ) + self._database._call_hooks = False + self._database._global_connection = self + self._connection_thread_lock = RLock() + # the asyncio locks are overwritten in python versions < 3.10 when using full_isolation + self._query_lock = asyncio.Lock() self._connection_lock = asyncio.Lock() + self._transaction_lock = asyncio.Lock() self._connection = self._backend.connection() self._connection.owner = self self._connection_counter = 0 - self._transaction_lock = asyncio.Lock() self._transaction_stack: typing.List[Transaction] = [] - self._query_lock = asyncio.Lock() self._force_rollback = force_rollback self.connection_transaction: typing.Optional[Transaction] = None - @multiloop_protector(False) - async def __aenter__(self) -> Connection: + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout + async def _aenter(self) -> None: async with self._connection_lock: self._connection_counter += 1 try: @@ -63,38 +108,98 @@ async def __aenter__(self) -> Connection: except BaseException as e: self._connection_counter -= 1 raise e + + async def __aenter__(self) -> Connection: + initialized = False + if self._full_isolation: + assert self._connection_thread_lock is not None + with self._connection_thread_lock: + if self._isolation_thread is None: + initialized = True + is_initialized = Event() + ctx = copy_context() + self._isolation_thread = thread = Thread( + target=ctx.run, + args=[ + _init_thread, + self._database, + is_initialized, + ], + daemon=True, + ) + thread.start() + is_initialized.wait() + if not thread.is_alive(): + self._isolation_thread = None + thread.join() + if not initialized: + await self._aenter() return self - @multiloop_protector(False) - async def __aexit__( - self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, - ) -> None: + async def _aexit_raw(self) -> bool: + closing = False async with self._connection_lock: assert self._connection is not None self._connection_counter -= 1 if self._connection_counter == 0: + closing = True try: if self.connection_transaction: # __aexit__ needs the connection_transaction parameter - await self.connection_transaction.__aexit__(exc_type, exc_value, traceback) + await self.connection_transaction.__aexit__() # untie, for allowing gc self.connection_transaction = None finally: await self._connection.release() self._database._connection = None + return closing + + async def _aexit(self) -> typing.Optional[Thread]: + if self._full_isolation: + assert self._connection_thread_lock is not None + with self._connection_thread_lock: + if await self._aexit_raw(): + loop = self._database._loop + thread = self._isolation_thread + if loop is not None: + loop.stop() + else: + self._isolation_thread = None + return thread + else: + await self._aexit_raw() + return None + + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + thread = None + try: + thread = await self._aexit() + finally: + if thread is not None and thread is not current_thread(): + thread.join() @property def _loop(self) -> typing.Any: return self._database._loop + @property + def _backend(self) -> interfaces.DatabaseBackend: + return self._database.backend + @multiloop_protector(False) async def fetch_all( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.List[interfaces.Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -106,6 +211,9 @@ async def fetch_one( query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, pos: int = 0, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Optional[interfaces.Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -118,6 +226,9 @@ async def fetch_val( values: typing.Optional[dict] = None, column: typing.Any = 0, pos: int = 0, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Any: built_query = self._build_query(query, values) async with self._query_lock: @@ -128,6 +239,9 @@ async def execute( self, query: typing.Union[ClauseElement, str], values: typing.Any = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Union[interfaces.Record, int]: if isinstance(query, str): built_query = self._build_query(query, values) @@ -139,7 +253,12 @@ async def execute( @multiloop_protector(False) async def execute_many( - self, query: typing.Union[ClauseElement, str], values: typing.Any = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Any = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Union[typing.Sequence[interfaces.Record], int]: if isinstance(query, str): built_query = self._build_query(query, None) @@ -149,49 +268,96 @@ async def execute_many( async with self._query_lock: return await self._connection.execute_many(query, values) - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) async def iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, - batch_size: typing.Optional[int] = None, - ) -> typing.AsyncGenerator[typing.Any, None]: + chunk_size: typing.Optional[int] = None, + timeout: typing.Optional[float] = None, + ) -> typing.AsyncGenerator[interfaces.Record, None]: built_query = self._build_query(query, values) + if timeout is None or timeout <= 0: + # anext is available in python 3.10 + + async def next_fn(inp: typing.Any) -> interfaces.Record: + return await aiterator.__anext__() + else: + + async def next_fn(inp: typing.Any) -> interfaces.Record: + return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout) + async with self._query_lock: - async for record in self._connection.iterate(built_query, batch_size): - yield record + aiterator = self._connection.iterate(built_query, chunk_size).__aiter__() + try: + while True: + yield await next_fn(aiterator) + except StopAsyncIteration: + pass - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) async def batched_iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, batch_size: typing.Optional[int] = None, - ) -> typing.AsyncGenerator[typing.Any, None]: + batch_wrapper: BatchCallable = tuple, + timeout: typing.Optional[float] = None, + ) -> typing.AsyncGenerator[BatchCallableResult, None]: built_query = self._build_query(query, values) + if timeout is None or timeout <= 0: + # anext is available in python 3.10 + + async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: + return await aiterator.__anext__() + else: + + async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: + return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout) + async with self._query_lock: - async for records in self._connection.batched_iterate(built_query, batch_size): - yield records + aiterator = self._connection.batched_iterate(built_query, batch_size).__aiter__() + try: + while True: + yield batch_wrapper(await next_fn(aiterator)) + except StopAsyncIteration: + pass @multiloop_protector(False) async def run_sync( self, fn: typing.Callable[..., typing.Any], *args: typing.Any, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout **kwargs: typing.Any, ) -> typing.Any: async with self._query_lock: return await self._connection.run_sync(fn, *args, **kwargs) @multiloop_protector(False) - async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def create_all( + self, + meta: MetaData, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: typing.Any, + ) -> None: await self.run_sync(meta.create_all, **kwargs) @multiloop_protector(False) - async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def drop_all( + self, + meta: MetaData, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: typing.Any, + ) -> None: await self.run_sync(meta.drop_all, **kwargs) - @multiloop_protector(False) def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": return Transaction(weakref.ref(self), force_rollback, **kwargs) @@ -202,7 +368,12 @@ def async_connection(self) -> typing.Any: return self._connection.async_connection @multiloop_protector(False) - async def get_raw_connection(self) -> typing.Any: + async def get_raw_connection( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> typing.Any: """The real raw connection (driver).""" return await self.async_connection.get_raw_connection() diff --git a/databasez/core/database.py b/databasez/core/database.py index 70da951..4d5c40b 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -150,6 +150,7 @@ class Database: options: typing.Any is_connected: bool = False _call_hooks: bool = True + _full_isolation: bool = False _force_rollback: ForceRollback # descriptor force_rollback = ForceRollbackDescriptor() @@ -160,6 +161,7 @@ def __init__( *, force_rollback: typing.Union[bool, None] = None, config: typing.Optional["DictAny"] = None, + full_isolation: typing.Union[bool, None] = None, **options: typing.Any, ): init() @@ -172,6 +174,8 @@ def __init__( self._call_hooks = url._call_hooks if force_rollback is None: force_rollback = bool(url.force_rollback) + if full_isolation is None: + full_isolation = bool(url._full_isolation) else: url = DatabaseURL(url) if config and "connection" in config: @@ -184,6 +188,9 @@ def __init__( ) if force_rollback is None: force_rollback = False + if full_isolation is None: + full_isolation = False + self._full_isolation = full_isolation self._force_rollback = ForceRollback(force_rollback) self.backend.owner = self self._connection_map = weakref.WeakKeyDictionary() @@ -264,10 +271,15 @@ async def connect(self) -> bool: if self._loop is not None and loop != self._loop: # copy when not in map if loop not in self._databases_map: + assert ( + self._global_connection is not None + ), "global connection should have been set" + # correctly initialize force_rollback with parent value + database = self.__class__( + self, force_rollback=bool(self.force_rollback), full_isolation=False + ) # prevent side effects of connect_hook - database = self.__copy__() database._call_hooks = False - assert self._global_connection database._global_connection = await self._global_connection.__aenter__() self._databases_map[loop] = database # forward call @@ -289,7 +301,8 @@ async def connect(self) -> bool: self.is_connected = True if self._global_connection is None: - self._global_connection = Connection(self, self.backend, force_rollback=True) + connection = Connection(self, force_rollback=True, full_isolation=self._full_isolation) + self._global_connection = connection return True async def disconnect_hook(self) -> None: @@ -315,9 +328,13 @@ async def disconnect( loop = asyncio.get_running_loop() del parent_database._databases_map[loop] if force: - for sub_database in self._databases_map.values(): - await sub_database.disconnect(True) - self._databases_map = {} + if self._databases_map: + assert not self._databases_map, "sub databases still active, force terminate them" + for sub_database in self._databases_map.values(): + asyncio.run_coroutine_threadsafe( + sub_database.disconnect(True), sub_database._loop + ) + self._databases_map = {} assert not self._databases_map, "sub databases still active" try: @@ -353,111 +370,135 @@ async def __aexit__( ) -> None: await self.disconnect() - @multiloop_protector(False) async def fetch_all( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, + timeout: typing.Optional[float] = None, ) -> typing.List[interfaces.Record]: async with self.connection() as connection: - return await connection.fetch_all(query, values) + return await connection.fetch_all(query, values, timeout=timeout) - @multiloop_protector(False) async def fetch_one( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, pos: int = 0, + timeout: typing.Optional[float] = None, ) -> typing.Optional[interfaces.Record]: async with self.connection() as connection: - return await connection.fetch_one(query, values, pos=pos) - assert connection._connection_counter == 1 + return await connection.fetch_one(query, values, pos=pos, timeout=timeout) - @multiloop_protector(False) async def fetch_val( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, column: typing.Any = 0, pos: int = 0, + timeout: typing.Optional[float] = None, ) -> typing.Any: async with self.connection() as connection: - return await connection.fetch_val(query, values, column=column, pos=pos) + return await connection.fetch_val( + query, + values, + column=column, + pos=pos, + timeout=timeout, + ) - @multiloop_protector(False) async def execute( self, query: typing.Union[ClauseElement, str], values: typing.Any = None, + timeout: typing.Optional[float] = None, ) -> typing.Union[interfaces.Record, int]: async with self.connection() as connection: - return await connection.execute(query, values) + return await connection.execute(query, values, timeout=timeout) - @multiloop_protector(False) async def execute_many( - self, query: typing.Union[ClauseElement, str], values: typing.Any = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Any = None, + timeout: typing.Optional[float] = None, ) -> typing.Union[typing.Sequence[interfaces.Record], int]: async with self.connection() as connection: - return await connection.execute_many(query, values) + return await connection.execute_many(query, values, timeout=timeout) - @multiloop_protector(False) async def iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, chunk_size: typing.Optional[int] = None, + timeout: typing.Optional[float] = None, ) -> typing.AsyncGenerator[interfaces.Record, None]: async with self.connection() as connection: - async for record in connection.iterate(query, values, chunk_size): + async for record in connection.iterate(query, values, chunk_size, timeout=timeout): yield record - @multiloop_protector(False) async def batched_iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, batch_size: typing.Optional[int] = None, - batch_wrapper: typing.Union[BatchCallable] = tuple, + batch_wrapper: BatchCallable = tuple, + timeout: typing.Optional[float] = None, ) -> typing.AsyncGenerator[BatchCallableResult, None]: async with self.connection() as connection: - async for records in connection.batched_iterate(query, values, batch_size): - yield batch_wrapper(records) + async for batch in typing.cast( + typing.AsyncGenerator["BatchCallableResult", None], + connection.batched_iterate( + query, + values, + batch_wrapper=batch_wrapper, + batch_size=batch_size, + timeout=timeout, + ), + ): + yield batch - @multiloop_protector(True) def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": return Transaction(self.connection, force_rollback=force_rollback, **kwargs) - @multiloop_protector(False) async def run_sync( self, fn: typing.Callable[..., typing.Any], *args: typing.Any, + timeout: typing.Optional[float] = None, **kwargs: typing.Any, ) -> typing.Any: async with self.connection() as connection: - return await connection.run_sync(fn, *args, **kwargs) + return await connection.run_sync(fn, *args, **kwargs, timeout=timeout) - @multiloop_protector(False) - async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def create_all( + self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any + ) -> None: async with self.connection() as connection: - await connection.create_all(meta, **kwargs) + await connection.create_all(meta, **kwargs, timeout=timeout) - @multiloop_protector(False) - async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def drop_all( + self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any + ) -> None: async with self.connection() as connection: - await connection.drop_all(meta, **kwargs) + await connection.drop_all(meta, **kwargs, timeout=timeout) @multiloop_protector(False) - def connection(self) -> Connection: + def _non_global_connection( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> Connection: + if self._connection is None: + _connection = self._connection = Connection(self) + return _connection + return self._connection + + def connection(self, timeout: typing.Optional[float] = None) -> Connection: if not self.is_connected: raise RuntimeError("Database is not connected") if self.force_rollback: return typing.cast(Connection, self._global_connection) - - if self._connection is None: - _connection = self._connection = Connection(self, self.backend) - return _connection - return self._connection + return self._non_global_connection(timeout=timeout) @property @multiloop_protector(True) diff --git a/databasez/testclient.py b/databasez/testclient.py index c00b262..fc7959d 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -34,6 +34,7 @@ class DatabaseTestClient(Database): # is used for copying Database and DatabaseTestClientand providing an early url test_db_url: str # hooks for overwriting defaults of args with None + testclient_default_full_isolation: bool = True testclient_default_force_rollback: bool = False testclient_default_lazy_setup: bool = False # customization hooks @@ -47,6 +48,7 @@ def __init__( *, force_rollback: typing.Union[bool, None] = None, use_existing: typing.Union[bool, None] = None, + full_isolation: typing.Union[bool, None] = None, drop_database: typing.Union[bool, None] = None, lazy_setup: typing.Union[bool, None] = None, test_prefix: typing.Union[str, None] = None, @@ -56,6 +58,8 @@ def __init__( use_existing = self.testclient_default_use_existing if drop_database is None: drop_database = self.testclient_default_drop_database + if full_isolation is None: + full_isolation = self.testclient_default_full_isolation if test_prefix is None: test_prefix = self.testclient_default_test_prefix self._setup_executed_init = False @@ -144,7 +148,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database for db in (database, "postgres", "template1", "template0", None): url = url.replace(database=db) - async with Database(url) as db_client: + async with Database(url, full_isolation=False, force_rollback=False) as db_client: try: return bool(await _get_scalar_result(db_client.engine, sa.text(text))) except (ProgrammingError, OperationalError): @@ -157,7 +161,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " "WHERE SCHEMA_NAME = '%s'" % database ) - async with Database(url) as db_client: + async with Database(url, full_isolation=False, force_rollback=False) as db_client: return bool(await _get_scalar_result(db_client.engine, sa.text(text))) elif dialect_name == "sqlite": @@ -169,7 +173,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> return True else: text = "SELECT 1" - async with Database(url) as db_client: + async with Database(url, full_isolation=False, force_rollback=False) as db_client: try: return bool(await _get_scalar_result(db_client.engine, sa.text(text))) except (ProgrammingError, OperationalError): @@ -201,9 +205,11 @@ async def create_database( dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} ): - db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False) + db_client = Database( + url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False + ) else: - db_client = Database(url, force_rollback=False) + db_client = Database(url, force_rollback=False, full_isolation=False) async with db_client: if dialect_name == "postgresql": if not template: @@ -255,9 +261,11 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N dialect_name == "postgresql" and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"} ): - db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False) + db_client = Database( + url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False + ) else: - db_client = Database(url, force_rollback=False) + db_client = Database(url, force_rollback=False, full_isolation=False) async with db_client: if dialect_name == "sqlite" and database and database != ":memory:": try: diff --git a/databasez/utils.py b/databasez/utils.py index 07262f5..b2c7cc4 100644 --- a/databasez/utils.py +++ b/databasez/utils.py @@ -140,28 +140,50 @@ def join(self, timeout: typing.Union[float, int, None] = None) -> None: MultiloopProtectorCallable = typing.TypeVar("MultiloopProtectorCallable", bound=typing.Callable) +def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any: + if timeout is not None and timeout > 0 and inspect.isawaitable(inp): + inp = asyncio.wait_for(inp, timeout=timeout) + return inp + + +async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any: + if timeout is not None and timeout > 0 and inspect.isawaitable(inp): + inp = await asyncio.wait_for(inp, timeout=timeout) + elif inspect.isawaitable(inp): + return await inp + return inp + + class AsyncHelperDatabase: def __init__( self, database: typing.Any, fn: typing.Callable, - *args: typing.Any, - **kwargs: typing.Any, + args: typing.Any, + kwargs: typing.Any, + timeout: typing.Optional[float], ) -> None: - self.database = database.__copy__() - self.fn = partial(fn, self.database, *args, **kwargs) + self.database = database + self.fn = fn + self.args = args + self.kwargs = kwargs + self.timeout = timeout self.ctm = None async def call(self) -> typing.Any: - async with self.database: - return await self.fn() + async with self.database as database: + return await _arun_with_timeout( + self.fn(database, *self.args, **self.kwargs), self.timeout + ) def __await__(self) -> typing.Any: return self.call().__await__() async def __aenter__(self) -> typing.Any: - await self.database.__aenter__() - self.ctm = self.fn() + database = await self.database.__aenter__() + self.ctm = await _arun_with_timeout( + self.fn(database, *self.args, **self.kwargs), timeout=self.timeout + ) return await self.ctm.__aenter__() async def __aexit__( @@ -172,7 +194,7 @@ async def __aexit__( ) -> None: assert self.ctm is not None try: - await self.ctm.__aexit__(exc_type, exc_value, traceback) + await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None) finally: await self.database.__aexit__() @@ -182,18 +204,19 @@ def __init__( self, connection: typing.Any, fn: typing.Callable, - *args: typing.Any, - **kwargs: typing.Any, + args: typing.Any, + kwargs: typing.Any, + timeout: typing.Optional[float], ) -> None: self.connection = connection self.fn = partial(fn, self.connection, *args, **kwargs) + self.timeout = timeout self.ctm = None async def call(self) -> typing.Any: async with self.connection: - result = self.fn() - if inspect.isawaitable(result): - result = await result + # is automatically awaited + result = await _arun_with_timeout(self.fn(), self.timeout) return result async def acall(self) -> typing.Any: @@ -231,7 +254,7 @@ async def __aexit__( def multiloop_protector( - fail_with_different_loop: bool, inject_parent: bool = False + fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False ) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: """For multiple threads or other reasons why the loop changes""" @@ -239,7 +262,14 @@ def multiloop_protector( # needs _loop attribute to check against def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable: @wraps(fn) - def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + def wrapper( + self: typing.Any, + *args: typing.Any, + **kwargs: typing.Any, + ) -> typing.Any: + timeout: typing.Optional[float] = None + if not passthrough_timeout and "timeout" in kwargs: + timeout = kwargs.pop("timeout") if inject_parent: assert "parent_database" not in kwargs, '"parent_database" is a reserved keyword' try: @@ -262,8 +292,8 @@ def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing if hasattr(self, "_databases_map") else AsyncHelperConnection ) - return helper(self, fn, *args, **kwargs) - return fn(self, *args, **kwargs) + return helper(self, fn, args, kwargs, timeout=timeout) + return _run_with_timeout(fn(self, *args, **kwargs), timeout=timeout) return typing.cast(MultiloopProtectorCallable, wrapper) diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md index ee8300d..19f2b2f 100644 --- a/docs/connections-and-transactions.md +++ b/docs/connections-and-transactions.md @@ -24,6 +24,13 @@ from databasez import Database Default: `None` +* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection + in an extra thread. This way it is possible to use blocking operations like locks with force_rollback. + This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized). + It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter. + + Default: `None` + * **config** - A python like dictionary as alternative to the `url` that contains the information to connect to the database. @@ -35,6 +42,8 @@ to connect to the database. Be careful when setting up the `url` or `config`. You can use one or the other but not both at the same time. +!!! Warning + `full_isolation` is not mature and shouldn't be used in production code. **Attributes*** diff --git a/docs/queries.md b/docs/queries.md index c6af888..1594e65 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -49,7 +49,7 @@ execute it with databasez. ## Queries Since you can use [SQLAlchemy core](https://docs.sqlalchemy.org/en/20/core/), that also means you -can also use the queries. Check out the [official tutorial](https://docs.sqlalchemy.org/en/14/core/tutorial.html). +can also use the queries. Check out the [official tutorial](https://docs.sqlalchemy.org/en/20/tutorial/). ```python {!> ../docs_src/queries/queries.py !} diff --git a/docs/release-notes.md b/docs/release-notes.md index c099b24..10315a2 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,6 +1,22 @@ # Release Notes +## 0.10.0 + + +### Added + +- `full_isolation` parameter. Isolate the force_rollback Connection in a thread. +- Timeouts for operations. + +### Fixed + +- `batched_iterate` interface of `Connection` differed from the one of `Database`. +- `iterate` interface of `Connection` differed from the one of `Database`. +- Hooks were called on automatically created Database objects. +- More multithreading safety. + + ## 0.9.7 ### Added diff --git a/docs/test-client.md b/docs/test-client.md index df2b4db..3771b11 100644 --- a/docs/test-client.md +++ b/docs/test-client.md @@ -48,6 +48,14 @@ from databasez.testclient import DatabaseTestClient Default: `None`, copy default or `testclient_default_force_rollback` (defaults to `False`) +* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection + in an extra thread. This way it is possible to use blocking operations like locks with force_rollback. + This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized). + It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter. + For the DatabaseTestClient it is enabled by default. + + Default: `None`, copy default or `testclient_default_full_isolation` (defaults to `True`) + * **lazy_setup** - This sets up the db first up on connect not in init. Default: `None`, True if copying a database or `testclient_default_lazy_setup` (defaults to `False`) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..c89d2c3 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,246 @@ +import asyncio +import contextvars +import functools +import os +from concurrent.futures import Future +from threading import Thread + +import pyodbc +import pytest + +from databasez import Database, DatabaseURL +from tests.shared_db import ( + database_client, + notes, + stop_database_client, +) + +assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set." + +DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] + +if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): + DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) + + +try: + to_thread = asyncio.to_thread +except AttributeError: + # for py <= 3.8 + async def to_thread(func, /, *args, **kwargs): + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + +@pytest.fixture(params=DATABASE_URLS) +def database_url(request): + """Yield test database despite its name""" + # yield test Databases + loop = asyncio.new_event_loop() + database = loop.run_until_complete(database_client(request.param)) + yield database + loop.run_until_complete(stop_database_client(database)) + + +def _startswith(tested, params): + for param in params: + if tested.startswith(param): + return True + return False + + +@pytest.mark.asyncio +async def test_concurrent_access_on_single_connection(database_url): + database_url = DatabaseURL(str(database_url.url)) + if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): + pytest.skip("Test requires sleep function") + async with Database(database_url, force_rollback=True, full_isolation=False) as database: + + async def db_lookup(): + if database_url.dialect.startswith("postgres"): + await database.fetch_one("SELECT pg_sleep(0.3)") + elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( + "mariadb" + ): + await database.fetch_one("SELECT SLEEP(0.3)") + elif database_url.dialect.startswith("mssql"): + await database.execute("WAITFOR DELAY '00:00:00.300'") + + await asyncio.gather(db_lookup(), db_lookup(), db_lookup()) + + +def _future_helper(awaitable, future): + try: + future.set_result(asyncio.run(awaitable)) + except BaseException as exc: + future.set_exception(exc) + + +@pytest.mark.parametrize( + "join_type,full_isolation", + [ + ("to_thread", False), + ("to_thread", True), + ("thread_join_with_context", True), + ("thread_join_without_context", True), + ], +) +@pytest.mark.parametrize("force_rollback", [True, False]) +@pytest.mark.asyncio +async def test_multi_thread_db(database_url, force_rollback, join_type, full_isolation): + database_url = DatabaseURL(str(database_url.url)) + async with Database( + database_url, force_rollback=force_rollback, full_isolation=full_isolation + ) as database: + + async def db_lookup(in_thread): + async with database.connection() as conn: + assert bool(conn._database.force_rollback) == force_rollback + if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): + return + if database_url.dialect.startswith("postgres"): + await database.fetch_one("SELECT pg_sleep(0.3)") + elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( + "mariadb" + ): + await database.fetch_one("SELECT SLEEP(0.3)") + elif database_url.dialect.startswith("mssql"): + await database.execute("WAITFOR DELAY '00:00:00.300'") + + async def wrap_in_thread(): + if join_type.startswith("thread_join"): + future = Future() + args = [_future_helper, asyncio.wait_for(db_lookup(True), 3), future] + if join_type == "thread_join_with_context": + ctx = contextvars.copy_context() + args.insert(0, ctx.run) + thread = Thread(target=args[0], args=args[1:]) + thread.start() + future.result(4) + else: + await to_thread(asyncio.run, asyncio.wait_for(db_lookup(True), 3)) + + await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread()) + + +@pytest.mark.parametrize( + "join_type,full_isolation", + [ + ("to_thread", False), + ("to_thread", True), + ("thread_join_with_context", True), + ("thread_join_without_context", True), + ], +) +@pytest.mark.parametrize("force_rollback", [True, False]) +@pytest.mark.asyncio +async def test_multi_thread_db_contextmanager( + database_url, force_rollback, join_type, full_isolation +): + async with Database( + database_url, force_rollback=force_rollback, full_isolation=full_isolation + ) as database: + query = notes.insert().values(text="examplecontext", completed=True) + await database.execute(query, timeout=10) + database._non_copied_attribute = True + + async def db_connect(depth=3): + # many parallel and nested threads + async with database as new_database: + assert not hasattr(new_database, "_non_copied_attribute") + query = notes.select() + result = await database.fetch_one(query) + assert result.text == "examplecontext" + assert result.completed is True + # test delegate to sub database + assert database.engine is new_database.engine + # also this shouldn't fail because redirected + old_refcount = new_database.ref_counter + await database.connect() + assert new_database.ref_counter == old_refcount + 1 + await database.disconnect() + ops = [] + while depth >= 0: + depth -= 1 + ops.append(to_thread(asyncio.run, db_connect(depth=depth))) + await asyncio.gather(*ops) + assert new_database.ref_counter == 0 + + if join_type.startswith("thread_join"): + future = Future() + args = [_future_helper, asyncio.wait_for(db_connect(), 3), future] + if join_type == "thread_join_with_context": + ctx = contextvars.copy_context() + args.insert(0, ctx.run) + thread = Thread(target=args[0], args=args[1:]) + thread.start() + future.result(4) + else: + await to_thread(asyncio.run, asyncio.wait_for(db_connect(), 3)) + assert database.ref_counter == 0 + if force_rollback: + async with database: + query = notes.select() + result = await database.fetch_one(query) + assert result is None + + +@pytest.mark.asyncio +async def test_multi_thread_db_connect(database_url): + async with Database(database_url, force_rollback=True) as database: + + async def db_connect(): + await database.connect() + await database.fetch_one("SELECT 1") + await database.disconnect() + + await to_thread(asyncio.run, db_connect()) + + +@pytest.mark.asyncio +async def test_multi_thread_db_fails(database_url): + async with Database(database_url, force_rollback=True) as database: + + async def db_connect(): + # not in same loop + database.disconnect() + + with pytest.raises(RuntimeError): + await to_thread(asyncio.run, db_connect()) + + +@pytest.mark.asyncio +async def test_global_connection_is_initialized_lazily(database_url): + """ + Ensure that global connection is initialized at latest possible time + so it's _query_lock will belong to same event loop that async_adapter has + initialized. + + See https://github.com/dymmond/databasez/issues/157 for more context. + """ + + database_url = DatabaseURL(database_url.url) + if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): + pytest.skip("Test requires sleep function") + + database = Database(database_url, force_rollback=True) + + async def run_database_queries(): + async with database: + + async def db_lookup(): + if database_url.dialect.startswith("postgres"): + await database.fetch_one("SELECT pg_sleep(0.3)") + elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( + "mariadb" + ): + await database.fetch_one("SELECT SLEEP(0.3)") + elif database_url.dialect.startswith("mssql"): + await database.execute("WAITFOR DELAY '00:00:00.300'") + + await asyncio.gather(db_lookup(), db_lookup(), db_lookup()) + + await run_database_queries() + await database.disconnect() diff --git a/tests/test_databases.py b/tests/test_databases.py index c851e94..cee4044 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,8 +1,6 @@ import asyncio -import contextvars import datetime import decimal -import functools import os from collections.abc import Sequence from unittest.mock import patch @@ -57,17 +55,6 @@ MIXED_DATABASE_CONFIG_URLS_IDS = [*DATABASE_URLS, *(f"{x}[config]" for x in DATABASE_URLS)] -try: - to_thread = asyncio.to_thread -except AttributeError: - # for py <= 3.8 - async def to_thread(func, /, *args, **kwargs): - loop = asyncio.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - @pytest.fixture(params=DATABASE_URLS) def database_url(request): """Yield test database despite its name""" @@ -831,127 +818,8 @@ async def test_database_url_interface(database_mixed_url): assert database.url == database_mixed_url -def _startswith(tested, params): - for param in params: - if tested.startswith(param): - return True - return False - - -@pytest.mark.asyncio -async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(str(database_url.url)) - if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): - pytest.skip("Test requires sleep function") - async with Database(database_url, force_rollback=True) as database: - - async def db_lookup(): - if database_url.dialect.startswith("postgres"): - await database.fetch_one("SELECT pg_sleep(0.3)") - elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( - "mariadb" - ): - await database.fetch_one("SELECT SLEEP(0.3)") - elif database_url.dialect.startswith("mssql"): - await database.execute("WAITFOR DELAY '00:00:00.300'") - - await asyncio.gather(db_lookup(), db_lookup(), db_lookup()) - - -@pytest.mark.parametrize("force_rollback", [True, False]) -@pytest.mark.asyncio -async def test_multi_thread(database_url, force_rollback): - database_url = DatabaseURL(str(database_url.url)) - async with Database(database_url, force_rollback=force_rollback) as database: - database._non_copied_attribute = True - - async def db_lookup(in_thread): - async with database.connection() as conn: - if in_thread: - assert not hasattr(conn._database, "_non_copied_attribute") - else: - assert hasattr(conn._database, "_non_copied_attribute") - assert bool(conn._database.force_rollback) == force_rollback - if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): - return - if database_url.dialect.startswith("postgres"): - await database.fetch_one("SELECT pg_sleep(0.3)") - elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( - "mariadb" - ): - await database.fetch_one("SELECT SLEEP(0.3)") - elif database_url.dialect.startswith("mssql"): - await database.execute("WAITFOR DELAY '00:00:00.300'") - - async def wrap_in_thread(): - await to_thread(asyncio.run, db_lookup(True)) - - await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread()) - - -@pytest.mark.parametrize("force_rollback", [True, False]) @pytest.mark.asyncio -async def test_multi_thread_db_contextmanager(database_url, force_rollback): - async with Database(database_url, force_rollback=force_rollback) as database: - query = notes.insert().values(text="examplecontext", completed=True) - await database.execute(query) - - async def db_connect(depth=3): - # many parallel and nested threads - async with database as new_database: - query = notes.select() - result = await database.fetch_one(query) - assert result.text == "examplecontext" - assert result.completed is True - # test delegate to sub database - assert database.engine is new_database.engine - # also this shouldn't fail because redirected - old_refcount = new_database.ref_counter - await database.connect() - assert new_database.ref_counter == old_refcount + 1 - await database.disconnect() - ops = [] - while depth >= 0: - depth -= 1 - ops.append(to_thread(asyncio.run, db_connect(depth=depth))) - await asyncio.gather(*ops) - assert new_database.ref_counter == 0 - - await to_thread(asyncio.run, db_connect()) - assert database.ref_counter == 0 - if force_rollback: - async with database: - query = notes.select() - result = await database.fetch_one(query) - assert result is None - - -@pytest.mark.asyncio -async def test_multi_thread_db_connect(database_url): - async with Database(database_url, force_rollback=True) as database: - - async def db_connect(): - await database.connect() - await database.fetch_one("SELECT 1") - await database.disconnect() - - await to_thread(asyncio.run, db_connect()) - - -@pytest.mark.asyncio -async def test_multi_thread_db_fails(database_url): - async with Database(database_url, force_rollback=True) as database: - - async def db_connect(): - # not in same loop - database.disconnect() - - with pytest.raises(RuntimeError): - await to_thread(asyncio.run, db_connect()) - - -@pytest.mark.asyncio -async def test_error_on_passed_parent_database(database_url): +async def test_error_on_passed_parent_database_argument(database_url): database = Database(database_url) # don't allow specifying parent_database with pytest.raises(AssertionError): @@ -962,41 +830,6 @@ async def test_error_on_passed_parent_database(database_url): await database.disconnect(False, None) -@pytest.mark.asyncio -async def test_global_connection_is_initialized_lazily(database_url): - """ - Ensure that global connection is initialized at latest possible time - so it's _query_lock will belong to same event loop that async_adapter has - initialized. - - See https://github.com/dymmond/databasez/issues/157 for more context. - """ - - database_url = DatabaseURL(database_url.url) - if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]): - pytest.skip("Test requires sleep function") - - database = Database(database_url, force_rollback=True) - - async def run_database_queries(): - async with database: - - async def db_lookup(): - if database_url.dialect.startswith("postgres"): - await database.fetch_one("SELECT pg_sleep(0.3)") - elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith( - "mariadb" - ): - await database.fetch_one("SELECT SLEEP(0.3)") - elif database_url.dialect.startswith("mssql"): - await database.execute("WAITFOR DELAY '00:00:00.300'") - - await asyncio.gather(db_lookup(), db_lookup(), db_lookup()) - - await run_database_queries() - await database.disconnect() - - @pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"]) @pytest.mark.asyncio async def test_column_names(database_url, select_query):