From 49ab625e7280bcd9dea35131b709e61d89b15f11 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 26 Aug 2024 12:43:17 +0200 Subject: [PATCH] Changes: - finish full_isolation - fix interface of Connection --- databasez/core/connection.py | 104 ++++++++++++++++++++++----- databasez/core/database.py | 33 ++++++--- databasez/testclient.py | 2 +- databasez/utils.py | 14 ++-- docs/connections-and-transactions.md | 4 +- docs/release-notes.md | 5 ++ docs/test-client.md | 8 +++ 7 files changed, 135 insertions(+), 35 deletions(-) diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 4eb2658..e1d97a3 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -18,12 +18,14 @@ from sqlalchemy import MetaData from sqlalchemy.sql import ClauseElement + from databasez.types import BatchCallable, BatchCallableResult + from .database import Database async def _startup(database: Database, is_initialized: Event) -> None: await database.connect() - await database._global_connection._aenter() + await database._global_connection._aenter() # type: ignore is_initialized.set() @@ -73,7 +75,7 @@ def __init__( self._force_rollback = force_rollback self.connection_transaction: typing.Optional[Transaction] = None - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout async def _aenter(self) -> None: async with self._connection_lock: self._connection_counter += 1 @@ -158,9 +160,9 @@ async def _aexit(self) -> typing.Optional[Thread]: return thread else: await self._aexit_raw() - return None + return None - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout async def __aexit__( self, exc_type: typing.Optional[typing.Type[BaseException]] = None, @@ -179,7 +181,7 @@ def _loop(self) -> typing.Any: return self._database._loop @property - def _backend(self) -> typing.Any: + def _backend(self) -> interfaces.DatabaseBackend: return self._database.backend @multiloop_protector(False) @@ -187,6 +189,9 @@ async def fetch_all( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.List[interfaces.Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -198,6 +203,9 @@ async def fetch_one( query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, pos: int = 0, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Optional[interfaces.Record]: built_query = self._build_query(query, values) async with self._query_lock: @@ -210,6 +218,9 @@ async def fetch_val( values: typing.Optional[dict] = None, column: typing.Any = 0, pos: int = 0, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Any: built_query = self._build_query(query, values) async with self._query_lock: @@ -220,6 +231,9 @@ async def execute( self, query: typing.Union[ClauseElement, str], values: typing.Any = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Union[interfaces.Record, int]: if isinstance(query, str): built_query = self._build_query(query, values) @@ -231,7 +245,12 @@ async def execute( @multiloop_protector(False) async def execute_many( - self, query: typing.Union[ClauseElement, str], values: typing.Any = None + self, + query: typing.Union[ClauseElement, str], + values: typing.Any = None, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout ) -> typing.Union[typing.Sequence[interfaces.Record], int]: if isinstance(query, str): built_query = self._build_query(query, None) @@ -241,46 +260,90 @@ async def execute_many( async with self._query_lock: return await self._connection.execute_many(query, values) - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) async def iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, - batch_size: typing.Optional[int] = None, - ) -> typing.AsyncGenerator[typing.Any, None]: + chunk_size: typing.Optional[int] = None, + timeout: typing.Optional[float] = None, + ) -> typing.AsyncGenerator[interfaces.Record, None]: built_query = self._build_query(query, values) + if timeout is None or timeout <= 0: + next_fn: typing.Callable[[typing.Any], typing.Awaitable[interfaces.Record]] = anext + else: + + async def next_fn(inp: typing.Any) -> interfaces.Record: + return await asyncio.wait_for(anext(aiterator), timeout=timeout) + async with self._query_lock: - async for record in self._connection.iterate(built_query, batch_size): - yield record + aiterator = self._connection.iterate(built_query, chunk_size).__aiter__() + try: + while True: + yield await next_fn(aiterator) + except StopAsyncIteration: + pass - @multiloop_protector(False) + @multiloop_protector(False, passthrough_timeout=True) async def batched_iterate( self, query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, batch_size: typing.Optional[int] = None, - ) -> typing.AsyncGenerator[typing.Any, None]: + batch_wrapper: BatchCallable = tuple, + timeout: typing.Optional[float] = None, + ) -> typing.AsyncGenerator[BatchCallableResult, None]: built_query = self._build_query(query, values) + if timeout is None or timeout <= 0: + next_fn: typing.Callable[ + [typing.Any], typing.Awaitable[typing.Sequence[interfaces.Record]] + ] = anext + else: + + async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: + return await asyncio.wait_for(anext(aiterator), timeout=timeout) + async with self._query_lock: - async for records in self._connection.batched_iterate(built_query, batch_size): - yield records + aiterator = self._connection.batched_iterate(built_query, batch_size).__aiter__() + try: + while True: + yield batch_wrapper(await next_fn(aiterator)) + except StopAsyncIteration: + pass @multiloop_protector(False) async def run_sync( self, fn: typing.Callable[..., typing.Any], *args: typing.Any, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout **kwargs: typing.Any, ) -> typing.Any: async with self._query_lock: return await self._connection.run_sync(fn, *args, **kwargs) @multiloop_protector(False) - async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def create_all( + self, + meta: MetaData, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: typing.Any, + ) -> None: await self.run_sync(meta.create_all, **kwargs) @multiloop_protector(False) - async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None: + async def drop_all( + self, + meta: MetaData, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: typing.Any, + ) -> None: await self.run_sync(meta.drop_all, **kwargs) def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": @@ -293,7 +356,12 @@ def async_connection(self) -> typing.Any: return self._connection.async_connection @multiloop_protector(False) - async def get_raw_connection(self) -> typing.Any: + async def get_raw_connection( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> typing.Any: """The real raw connection (driver).""" return await self.async_connection.get_raw_connection() diff --git a/databasez/core/database.py b/databasez/core/database.py index 1031245..4d5c40b 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -150,6 +150,7 @@ class Database: options: typing.Any is_connected: bool = False _call_hooks: bool = True + _full_isolation: bool = False _force_rollback: ForceRollback # descriptor force_rollback = ForceRollbackDescriptor() @@ -398,7 +399,11 @@ async def fetch_val( ) -> typing.Any: async with self.connection() as connection: return await connection.fetch_val( - query, values, column=column, pos=pos, timeout=timeout + query, + values, + column=column, + pos=pos, + timeout=timeout, ) async def execute( @@ -435,14 +440,21 @@ async def batched_iterate( query: typing.Union[ClauseElement, str], values: typing.Optional[dict] = None, batch_size: typing.Optional[int] = None, - batch_wrapper: typing.Union[BatchCallable] = tuple, + batch_wrapper: BatchCallable = tuple, timeout: typing.Optional[float] = None, ) -> typing.AsyncGenerator[BatchCallableResult, None]: async with self.connection() as connection: - async for records in connection.batched_iterate( - query, values, batch_size, timeout=timeout + async for batch in typing.cast( + typing.AsyncGenerator["BatchCallableResult", None], + connection.batched_iterate( + query, + values, + batch_wrapper=batch_wrapper, + batch_size=batch_size, + timeout=timeout, + ), ): - yield batch_wrapper(records) + yield batch def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": return Transaction(self.connection, force_rollback=force_rollback, **kwargs) @@ -470,18 +482,23 @@ async def drop_all( await connection.drop_all(meta, **kwargs, timeout=timeout) @multiloop_protector(False) - def _non_global_connection(self) -> Connection: + def _non_global_connection( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> Connection: if self._connection is None: _connection = self._connection = Connection(self) return _connection return self._connection - def connection(self) -> Connection: + def connection(self, timeout: typing.Optional[float] = None) -> Connection: if not self.is_connected: raise RuntimeError("Database is not connected") if self.force_rollback: return typing.cast(Connection, self._global_connection) - return self._non_global_connection() + return self._non_global_connection(timeout=timeout) @property @multiloop_protector(True) diff --git a/databasez/testclient.py b/databasez/testclient.py index ae6be71..fc7959d 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -34,7 +34,7 @@ class DatabaseTestClient(Database): # is used for copying Database and DatabaseTestClientand providing an early url test_db_url: str # hooks for overwriting defaults of args with None - testclient_default_full_isolation: bool = False + testclient_default_full_isolation: bool = True testclient_default_force_rollback: bool = False testclient_default_lazy_setup: bool = False # customization hooks diff --git a/databasez/utils.py b/databasez/utils.py index 4b1b289..b2c7cc4 100644 --- a/databasez/utils.py +++ b/databasez/utils.py @@ -184,7 +184,7 @@ async def __aenter__(self) -> typing.Any: self.ctm = await _arun_with_timeout( self.fn(database, *self.args, **self.kwargs), timeout=self.timeout ) - return await _arun_with_timeout(self.ctm.__aenter__(), self.timeout) + return await self.ctm.__aenter__() async def __aexit__( self, @@ -194,9 +194,7 @@ async def __aexit__( ) -> None: assert self.ctm is not None try: - await _arun_with_timeout( - self.ctm.__aexit__(exc_type, exc_value, traceback), self.timeout - ) + await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None) finally: await self.database.__aexit__() @@ -230,7 +228,7 @@ def __await__(self) -> typing.Any: async def enter_intern(self) -> typing.Any: await self.connection.__aenter__() self.ctm = await self.call() - return await _arun_with_timeout(self.ctm.__aenter__(), self.timeout) + return await self.ctm.__aenter__() async def exit_intern(self) -> typing.Any: assert self.ctm is not None @@ -256,7 +254,7 @@ async def __aexit__( def multiloop_protector( - fail_with_different_loop: bool, inject_parent: bool = False + fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False ) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: """For multiple threads or other reasons why the loop changes""" @@ -267,9 +265,11 @@ def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable: def wrapper( self: typing.Any, *args: typing.Any, - timeout: typing.Optional[float] = None, **kwargs: typing.Any, ) -> typing.Any: + timeout: typing.Optional[float] = None + if not passthrough_timeout and "timeout" in kwargs: + timeout = kwargs.pop("timeout") if inject_parent: assert "parent_database" not in kwargs, '"parent_database" is a reserved keyword' try: diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md index e6a1047..19f2b2f 100644 --- a/docs/connections-and-transactions.md +++ b/docs/connections-and-transactions.md @@ -25,7 +25,9 @@ from databasez import Database Default: `None` * **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection - in an extra thread. This way it is possible to use blocking operations. + in an extra thread. This way it is possible to use blocking operations like locks with force_rollback. + This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized). + It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter. Default: `None` diff --git a/docs/release-notes.md b/docs/release-notes.md index 9ec9f7f..cc79476 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -7,6 +7,11 @@ ### Added - `full_isolation` parameter. Isolate the force_rollback Connection in a thread. +- Timeouts for operations. + +### Fixed + +- `batched_iterate` interface of `Connection` differed from the one of `Database`. ## 0.9.7 diff --git a/docs/test-client.md b/docs/test-client.md index df2b4db..3771b11 100644 --- a/docs/test-client.md +++ b/docs/test-client.md @@ -48,6 +48,14 @@ from databasez.testclient import DatabaseTestClient Default: `None`, copy default or `testclient_default_force_rollback` (defaults to `False`) +* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection + in an extra thread. This way it is possible to use blocking operations like locks with force_rollback. + This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized). + It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter. + For the DatabaseTestClient it is enabled by default. + + Default: `None`, copy default or `testclient_default_full_isolation` (defaults to `True`) + * **lazy_setup** - This sets up the db first up on connect not in init. Default: `None`, True if copying a database or `testclient_default_lazy_setup` (defaults to `False`)