From 08f10793fa24485283d03a7223d27c19cf9b212a Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 5 Sep 2024 14:15:43 +0200 Subject: [PATCH 1/4] Changes: - fix transactions with multithreading - stack the transaction backend instances on the transaction stack and remove ACTIVE_TRANSACTIONS (unreliable with multithreading) --- databasez/__init__.py | 2 +- databasez/core/__init__.py | 4 +- databasez/core/connection.py | 58 +++++++++- databasez/core/database.py | 56 ++++++++- databasez/core/transaction.py | 164 ++++++++++++++++----------- databasez/utils.py | 125 +------------------- docs/connections-and-transactions.md | 4 + docs/release-notes.md | 15 +++ tests/test_concurrency.py | 44 +++++-- tests/test_transactions.py | 69 +---------- 10 files changed, 268 insertions(+), 273 deletions(-) diff --git a/databasez/__init__.py b/databasez/__init__.py index 5ff66ec..95cafc6 100644 --- a/databasez/__init__.py +++ b/databasez/__init__.py @@ -1,5 +1,5 @@ from databasez.core import Database, DatabaseURL -__version__ = "0.10.1" +__version__ = "0.10.2" __all__ = ["Database", "DatabaseURL"] diff --git a/databasez/core/__init__.py b/databasez/core/__init__.py index dc429ea..b7a685c 100644 --- a/databasez/core/__init__.py +++ b/databasez/core/__init__.py @@ -1,6 +1,6 @@ from .connection import Connection from .database import Database, init from .databaseurl import DatabaseURL -from .transaction import ACTIVE_TRANSACTIONS, Transaction +from .transaction import Transaction -__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction", "ACTIVE_TRANSACTIONS"] +__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction"] diff --git a/databasez/core/connection.py b/databasez/core/connection.py index 9a214ce..b0531bc 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -3,13 +3,14 @@ import asyncio import typing import weakref +from functools import partial from threading import Event, Lock, Thread, current_thread from types import TracebackType from sqlalchemy import text from databasez import interfaces -from databasez.utils import multiloop_protector +from databasez.utils import _arun_with_timeout, arun_coroutine_threadsafe, multiloop_protector from .transaction import Transaction @@ -59,7 +60,50 @@ def _init_thread( database._global_connection._isolation_thread = None # type: ignore +class AsyncHelperConnection: + def __init__( + self, + connection: Connection, + fn: typing.Callable, + args: typing.Any, + kwargs: typing.Any, + timeout: typing.Optional[float], + ) -> None: + self.connection = connection + self.fn = partial(fn, self.connection, *args, **kwargs) + self.timeout = timeout + self.ctm = None + + async def call(self) -> typing.Any: + # is automatically awaited + result = await _arun_with_timeout(self.fn(), self.timeout) + return result + + async def acall(self) -> typing.Any: + return await arun_coroutine_threadsafe( + self.call(), self.connection._loop, self.connection.poll_interval + ) + + def __await__(self) -> typing.Any: + return self.acall().__await__() + + async def __aiter__(self) -> typing.Any: + result = await self.acall() + try: + while True: + yield await arun_coroutine_threadsafe( + _arun_with_timeout(result.__anext__(), self.timeout), + self.connection._loop, + self.connection.poll_interval, + ) + except StopAsyncIteration: + pass + + class Connection: + # async helper + async_helper: typing.Type[AsyncHelperConnection] = AsyncHelperConnection + def __init__( self, database: Database, force_rollback: bool = False, full_isolation: bool = False ) -> None: @@ -86,11 +130,18 @@ def __init__( self._connection.owner = self self._connection_counter = 0 - self._transaction_stack: typing.List[Transaction] = [] + # for keeping weak references to transactions active + self._transaction_stack: typing.List[ + typing.Tuple[Transaction, interfaces.TransactionBackend] + ] = [] self._force_rollback = force_rollback self.connection_transaction: typing.Optional[Transaction] = None + @multiloop_protector(True) + def _get_connection_backend(self) -> interfaces.ConnectionBackend: + return self._connection + @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout async def _aenter(self) -> None: async with self._connection_lock: @@ -111,8 +162,7 @@ async def _aenter(self) -> None: self.connection_transaction = self.transaction( force_rollback=self._force_rollback ) - # make re-entrant, we have already the connection lock - await self.connection_transaction.start(True) + await self.connection_transaction.start() except BaseException as e: self._connection_counter -= 1 raise e diff --git a/databasez/core/database.py b/databasez/core/database.py index d5cb2d1..d2e377a 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -11,7 +11,12 @@ from types import TracebackType from databasez import interfaces -from databasez.utils import DATABASEZ_POLL_INTERVAL, arun_coroutine_threadsafe, multiloop_protector +from databasez.utils import ( + DATABASEZ_POLL_INTERVAL, + _arun_with_timeout, + arun_coroutine_threadsafe, + multiloop_protector, +) from .connection import Connection from .databaseurl import DatabaseURL @@ -118,6 +123,51 @@ def __delete__(self, obj: Database) -> None: obj._force_rollback.set(None) +class AsyncHelperDatabase: + def __init__( + self, + database: Database, + fn: typing.Callable, + args: typing.Any, + kwargs: typing.Any, + timeout: typing.Optional[float], + ) -> None: + self.database = database + self.fn = fn + self.args = args + self.kwargs = kwargs + self.timeout = timeout + self.ctm = None + + async def call(self) -> typing.Any: + async with self.database as database: + return await _arun_with_timeout( + self.fn(database, *self.args, **self.kwargs), self.timeout + ) + + def __await__(self) -> typing.Any: + return self.call().__await__() + + async def __aenter__(self) -> typing.Any: + database = await self.database.__aenter__() + self.ctm = await _arun_with_timeout( + self.fn(database, *self.args, **self.kwargs), timeout=self.timeout + ) + return await self.ctm.__aenter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]] = None, + exc_value: typing.Optional[BaseException] = None, + traceback: typing.Optional[TracebackType] = None, + ) -> None: + assert self.ctm is not None + try: + await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None) + finally: + await self.database.__aexit__() + + class Database: """ An abstraction on the top of the EncodeORM databases.Database object. @@ -156,6 +206,8 @@ class Database: _force_rollback: ForceRollback # descriptor force_rollback = ForceRollbackDescriptor() + # async helper + async_helper: typing.Type[AsyncHelperDatabase] = AsyncHelperDatabase def __init__( self, @@ -195,7 +247,7 @@ def __init__( if force_rollback is None: force_rollback = False if full_isolation is None: - full_isolation = False + full_isolation = True if poll_interval is None: poll_interval = DATABASEZ_POLL_INTERVAL self.poll_interval = poll_interval diff --git a/databasez/core/transaction.py b/databasez/core/transaction.py index 2c07a70..77e7500 100644 --- a/databasez/core/transaction.py +++ b/databasez/core/transaction.py @@ -1,13 +1,11 @@ from __future__ import annotations -import functools +import asyncio import typing -import weakref -from contextlib import AsyncExitStack -from contextvars import ContextVar +from functools import partial, wraps from types import TracebackType -from databasez import interfaces +from databasez.utils import _arun_with_timeout, arun_coroutine_threadsafe, multiloop_protector if typing.TYPE_CHECKING: from .connection import Connection @@ -16,12 +14,37 @@ _CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) -ACTIVE_TRANSACTIONS: ContextVar[ - typing.Optional[weakref.WeakKeyDictionary[Transaction, interfaces.TransactionBackend]] -] = ContextVar("ACTIVE_TRANSACTIONS", default=None) +class AsyncHelperTransaction: + def __init__( + self, + transaction: typing.Any, + fn: typing.Callable, + args: typing.Any, + kwargs: typing.Any, + timeout: typing.Optional[float], + ) -> None: + self.transaction = transaction + self.fn = partial(fn, self.transaction, *args, **kwargs) + self.timeout = timeout + self.ctm = None + + async def call(self) -> typing.Any: + # is automatically awaited + return await _arun_with_timeout(self.fn(), self.timeout) + + async def acall(self) -> typing.Any: + return await arun_coroutine_threadsafe( + self.call(), self.transaction._loop, self.transaction.poll_interval + ) + + def __await__(self) -> typing.Any: + return self.acall().__await__() class Transaction: + # async helper + async_helper: typing.Type[AsyncHelperTransaction] = AsyncHelperTransaction + def __init__( self, connection_callable: typing.Callable[[], typing.Optional[Connection]], @@ -42,42 +65,20 @@ def connection(self) -> Connection: return conn @property - def _transaction(self) -> typing.Optional[interfaces.TransactionBackend]: - transactions = ACTIVE_TRANSACTIONS.get() - if transactions is None: - return None - - return transactions.get(self, None) - - @_transaction.setter - def _transaction( - self, transaction: typing.Optional[interfaces.TransactionBackend] - ) -> typing.Optional[interfaces.TransactionBackend]: - transactions = ACTIVE_TRANSACTIONS.get() - if transactions is None: - # shortcut, we don't need to initialize anything for None (remove transaction) - if transaction is None: - return None - transactions = weakref.WeakKeyDictionary() - else: - transactions = transactions.copy() - - if transaction is None: - transactions.pop(self, None) - else: - transactions[self] = transaction - # It is always a copy required to - # prevent sideeffects between contexts - ACTIVE_TRANSACTIONS.set(transactions) + def _loop(self) -> typing.Optional[asyncio.AbstractEventLoop]: + return self.connection._loop - return transactions.get(self, None) + @property + def poll_interval(self) -> float: + return self.connection.poll_interval async def __aenter__(self) -> Transaction: """ Called when entering `async with database.transaction()` """ + # when used with existing transaction, please call start if/when required if self._existing_transaction is None: - await self.start() + await self.start(cleanup_on_error=False) return self async def __aexit__( @@ -105,60 +106,91 @@ def __call__(self, func: _CallableType) -> _CallableType: Called if using `@database.transaction()` as a decorator. """ - @functools.wraps(func) + @wraps(func) async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async with self: - return await func(*args, **kwargs) + await func(*args, **kwargs) return wrapper # type: ignore - async def start(self, without_transaction_lock: bool = False) -> Transaction: - connection = self.connection + # called directly from connection + @multiloop_protector(False) + async def _start( + self, + connection: Connection, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> None: + assert connection._loop - async with AsyncExitStack() as cm: - if not without_transaction_lock: - await cm.enter_async_context(connection._transaction_lock) + async with connection._transaction_lock: is_root = not connection._transaction_stack - _transaction = connection._connection.transaction(self._existing_transaction) + # we retrieve the base connection here, loop protection is required + _transaction = connection._get_connection_backend().transaction( + self._existing_transaction + ) _transaction.owner = self - # will be terminated with connection, don't bump - # fixes also a locking issue - if connection.connection_transaction is not self: - await connection.__aenter__() if self._existing_transaction is None: await _transaction.start(is_root=is_root, **self._extra_options) + # because we have an await before, we need the _transaction_lock self._transaction = _transaction - connection._transaction_stack.append(self) + connection._transaction_stack.append((self, _transaction)) + _transaction = self._transaction + + # called directly from connection + async def start( + self, + timeout: typing.Optional[float] = None, + cleanup_on_error: bool = True, + ) -> Transaction: + connection = self.connection + # count up connection and init multithreading-safe the isolation thread + # benefit 2: setup works with transaction_lock + if connection.connection_transaction is not self: + await connection.__aenter__() + # we have a loop now in case of full_isolation + try: + await self._start(connection, timeout=timeout) + except BaseException as exc: + # normal start call + if cleanup_on_error and connection.connection_transaction is not self: + await connection.__aexit__() + raise exc return self - async def commit(self) -> None: + @multiloop_protector(False) + async def commit( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> None: connection = self.connection async with connection._transaction_lock: - _transaction = self._transaction # some transactions are tied to connections and are not on the transaction stack - if _transaction is not None: - # delete transaction from ACTIVE_TRANSACTIONS - self._transaction = None - assert connection._transaction_stack[-1] is self - connection._transaction_stack.pop() + if connection._transaction_stack and connection._transaction_stack[-1][0] is self: + _, _transaction = connection._transaction_stack.pop() await _transaction.commit() - # if a connection_transaction, the connetion cleans it up in __aexit__ + # if a connection_transaction, the connection cleans it up in __aexit__ # prevent loop if connection.connection_transaction is not self: await connection.__aexit__() - async def rollback(self) -> None: + @multiloop_protector(False) + async def rollback( + self, + timeout: typing.Optional[ + float + ] = None, # stub for type checker, multiloop_protector handles timeout + ) -> None: connection = self.connection async with connection._transaction_lock: - _transaction = self._transaction # some transactions are tied to connections and are not on the transaction stack - if _transaction is not None: - # delete transaction from ACTIVE_TRANSACTIONS - self._transaction = None - assert connection._transaction_stack[-1] is self - connection._transaction_stack.pop() + if connection._transaction_stack and connection._transaction_stack[-1][0] is self: + _, _transaction = connection._transaction_stack.pop() await _transaction.rollback() - # if a connection_transaction, the connetion cleans it up in __aexit__ + # if a connection_transaction, the connection cleans it up in __aexit__ # prevent loop if connection.connection_transaction is not self: await connection.__aexit__() diff --git a/databasez/utils.py b/databasez/utils.py index d5285d6..4c432ec 100644 --- a/databasez/utils.py +++ b/databasez/utils.py @@ -4,7 +4,6 @@ from concurrent.futures import Future from functools import partial, wraps from threading import Thread -from types import TracebackType DATABASEZ_RESULT_TIMEOUT: typing.Optional[float] = None # Poll with 0.1ms, this way CPU isn't at 100% @@ -184,127 +183,12 @@ def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typin async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any: if timeout is not None and timeout > 0 and inspect.isawaitable(inp): - inp = await asyncio.wait_for(inp, timeout=timeout) + return await asyncio.wait_for(inp, timeout=timeout) elif inspect.isawaitable(inp): return await inp return inp -class AsyncHelperDatabase: - def __init__( - self, - database: typing.Any, - fn: typing.Callable, - args: typing.Any, - kwargs: typing.Any, - timeout: typing.Optional[float], - ) -> None: - self.database = database - self.fn = fn - self.args = args - self.kwargs = kwargs - self.timeout = timeout - self.ctm = None - - async def call(self) -> typing.Any: - async with self.database as database: - return await _arun_with_timeout( - self.fn(database, *self.args, **self.kwargs), self.timeout - ) - - def __await__(self) -> typing.Any: - return self.call().__await__() - - async def __aenter__(self) -> typing.Any: - database = await self.database.__aenter__() - self.ctm = await _arun_with_timeout( - self.fn(database, *self.args, **self.kwargs), timeout=self.timeout - ) - return await self.ctm.__aenter__() - - async def __aexit__( - self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, - ) -> None: - assert self.ctm is not None - try: - await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None) - finally: - await self.database.__aexit__() - - -class AsyncHelperConnection: - def __init__( - self, - connection: typing.Any, - fn: typing.Callable, - args: typing.Any, - kwargs: typing.Any, - timeout: typing.Optional[float], - ) -> None: - self.connection = connection - self.fn = partial(fn, self.connection, *args, **kwargs) - self.timeout = timeout - self.ctm = None - - async def call(self) -> typing.Any: - async with self.connection: - # is automatically awaited - result = await _arun_with_timeout(self.fn(), self.timeout) - return result - - async def acall(self) -> typing.Any: - return await arun_coroutine_threadsafe( - self.call(), self.connection._loop, self.connection.poll_interval - ) - - def __await__(self) -> typing.Any: - return self.acall().__await__() - - async def __aiter__(self) -> typing.Any: - result = await self.acall() - try: - while True: - yield await arun_coroutine_threadsafe( - _arun_with_timeout(result.__anext__(), self.timeout), - self.connection._loop, - self.connection.poll_interval, - ) - except StopAsyncIteration: - pass - - async def enter_intern(self) -> typing.Any: - await self.connection.__aenter__() - self.ctm = await self.call() - return await self.ctm.__aenter__() - - async def exit_intern(self) -> typing.Any: - assert self.ctm is not None - try: - await self.ctm.__aexit__() - finally: - self.ctm = None - await self.connection.__aexit__() - - async def __aenter__(self) -> typing.Any: - return await arun_coroutine_threadsafe( - self.enter_intern(), self.connection._loop, self.connection.poll_interval - ) - - async def __aexit__( - self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, - ) -> None: - assert self.ctm is not None - await arun_coroutine_threadsafe( - self.exit_intern(), self.connection._loop, self.connection.poll_interval - ) - - def multiloop_protector( fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False ) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: @@ -339,12 +223,7 @@ def wrapper( else: if fail_with_different_loop: raise RuntimeError("Different loop used") - helper = ( - AsyncHelperDatabase - if hasattr(self, "_databases_map") - else AsyncHelperConnection - ) - return helper(self, fn, args, kwargs, timeout=timeout) + return self.async_helper(self, fn, args, kwargs, timeout=timeout) return _run_with_timeout(fn(self, *args, **kwargs), timeout=timeout) return typing.cast(MultiloopProtectorCallable, wrapper) diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md index 3e47652..e8b613b 100644 --- a/docs/connections-and-transactions.md +++ b/docs/connections-and-transactions.md @@ -400,6 +400,10 @@ async with database.transaction(isolation_level="serializable"): ... ``` +!!! Warning + When using force_rollback, transactions require the parameter `full_isolation` when using multithreading. + Otherwise the stack is in disarray. + ## Reusing sqlalchemy engine of databasez For integration in other libraries databasez has also the AsyncEngine exposed via the `engine` property. diff --git a/docs/release-notes.md b/docs/release-notes.md index 2eb7b5c..a638489 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,20 @@ # Release Notes +## 0.10.2 + +### Fixed + +- Fix transactions in multi-threading contexts. + +### Changed + +- The transaction stack contains the backend too. + +### Removed + +- Remove `ACTIVE_TRANSACTIONS` ContextVar plus tests for it. It became unreliable with multithreading. + + ## 0.10.1 ### Added diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 03d7625..b57035c 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -88,8 +88,16 @@ def _future_helper(awaitable, future): ("thread_join_with_context", True), ("thread_join_without_context", True), ], + ids=[ + "to_thread-no_full_isolation", + "to_thread-full_isolation", + "thread_join_with_context-full_isolation", + "thread_join_without_context-full_isolation", + ], +) +@pytest.mark.parametrize( + "force_rollback", [True, False], ids=["force_rollback", "no_force_rollback"] ) -@pytest.mark.parametrize("force_rollback", [True, False]) @pytest.mark.asyncio async def test_multi_thread_db(database_url, force_rollback, join_type, full_isolation): database_url = DatabaseURL( @@ -146,8 +154,16 @@ async def wrap_in_thread(): ("thread_join_with_context", True), ("thread_join_without_context", True), ], + ids=[ + "to_thread-no_full_isolation", + "to_thread-full_isolation", + "thread_join_with_context-full_isolation", + "thread_join_without_context-full_isolation", + ], +) +@pytest.mark.parametrize( + "force_rollback", [True, False], ids=["force_rollback", "no_force_rollback"] ) -@pytest.mark.parametrize("force_rollback", [True, False]) def test_multi_thread_db_anyio( run_params, plain_database_url, force_rollback, join_type, full_isolation ): @@ -169,8 +185,16 @@ def test_multi_thread_db_anyio( ("thread_join_with_context", True), ("thread_join_without_context", True), ], + ids=[ + "to_thread-no_full_isolation", + "to_thread-full_isolation", + "thread_join_with_context-full_isolation", + "thread_join_without_context-full_isolation", + ], +) +@pytest.mark.parametrize( + "force_rollback", [True, False], ids=["force_rollback", "no_force_rollback"] ) -@pytest.mark.parametrize("force_rollback", [True, False]) @pytest.mark.asyncio async def test_multi_thread_db_contextmanager( database_url, force_rollback, join_type, full_isolation @@ -178,16 +202,22 @@ async def test_multi_thread_db_contextmanager( async with Database( database_url, force_rollback=force_rollback, full_isolation=full_isolation ) as database: - query = notes.insert().values(text="examplecontext", completed=True) - await database.execute(query, timeout=10) + if not str(database_url.url).startswith("sqlite"): + async with database.transaction(): + query = notes.insert().values(text="examplecontext", completed=True) + await database.execute(query, timeout=10) + else: + query = notes.insert().values(text="examplecontext", completed=True) + await database.execute(query, timeout=10) database._non_copied_attribute = True async def db_connect(depth=3): # many parallel and nested threads async with database as new_database: assert not hasattr(new_database, "_non_copied_attribute") - query = notes.select() - result = await database.fetch_one(query) + async with database.transaction(): + query = notes.select() + result = await database.fetch_one(query) assert result.text == "examplecontext" assert result.completed is True # test delegate to sub database diff --git a/tests/test_transactions.py b/tests/test_transactions.py index 5455716..d2b6664 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -163,71 +163,6 @@ async def check_child_connection(database: Database): await database.disconnect() -@pytest.mark.asyncio -async def test_transaction_context_cleanup_contextmanager(database_url): - """ - Ensure that contextvar transactions are not persisted unecessarily. - """ - from databasez.core import ACTIVE_TRANSACTIONS - - assert ACTIVE_TRANSACTIONS.get() is None - - async with Database(database_url) as database: - async with database.transaction() as transaction: - open_transactions = ACTIVE_TRANSACTIONS.get() - assert isinstance(open_transactions, MutableMapping) - assert open_transactions.get(transaction) is transaction._transaction - - # Context manager closes, open_transactions is cleaned up - open_transactions = ACTIVE_TRANSACTIONS.get() - assert isinstance(open_transactions, MutableMapping) - assert open_transactions.get(transaction, None) is None - - -@pytest.mark.asyncio -async def test_transaction_context_cleanup_garbagecollector(database_url): - """ - Ensure that contextvar transactions are not persisted unecessarily, even - if exit handlers are not called. - This test should be an XFAIL, but cannot be due to the way that is hangs - during teardown. - """ - from databasez.core import ACTIVE_TRANSACTIONS - - assert ACTIVE_TRANSACTIONS.get() is None - - async with Database(database_url) as database: - # Should be tracking the transaction - open_transactions = ACTIVE_TRANSACTIONS.get() - assert open_transactions is None - transaction = database.transaction() - await transaction.start() - # is replaced after start() call - open_transactions = ACTIVE_TRANSACTIONS.get() - assert len(open_transactions) == 1 - - assert open_transactions.get(transaction) is transaction._transaction - - # neither .commit, .rollback, nor .__aexit__ are called - del transaction - gc.collect() - - # A strong reference to the transaction is kept alive by the connection's - # ._transaction_stack, so it is still be tracked at this point. - assert len(open_transactions) == 1 - - # If that were magically cleared, the transaction would be cleaned up, - # but as it stands this always causes a hang during teardown at - # `Database(...).disconnect()` if the transaction is not closed. - transaction = database.connection()._transaction_stack[-1] - await transaction.rollback() - assert transaction.connection._connection_counter == 0 - del transaction - - # Now with the transaction rolled-back, it should be cleaned up. - assert len(open_transactions) == 0 - - @pytest.mark.asyncio async def test_iterate_outside_transaction_with_temp_table(database_url): """ @@ -466,7 +401,7 @@ async def test_transaction_decorator(database_url): """ Ensure that @database.transaction() is supported. """ - database = Database(database_url, force_rollback=True) + database = Database(database_url, force_rollback=True, full_isolation=True) @database.transaction() async def insert_data(raise_exception): @@ -570,11 +505,9 @@ async def check_transaction(transaction): # Parent task is now in a transaction, we should not # see its transaction backend since this task was # _started_ in a context where no transaction was active. - assert transaction._transaction is None end.set() transaction = database.transaction() - assert transaction._transaction is None task = asyncio.create_task(check_transaction(transaction)) async with transaction: From fba27f2a07848facfc0963be8a8e5b09a2103d8f Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 5 Sep 2024 14:20:17 +0200 Subject: [PATCH 2/4] remove obsolete warning --- docs/connections-and-transactions.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md index e8b613b..3e47652 100644 --- a/docs/connections-and-transactions.md +++ b/docs/connections-and-transactions.md @@ -400,10 +400,6 @@ async with database.transaction(isolation_level="serializable"): ... ``` -!!! Warning - When using force_rollback, transactions require the parameter `full_isolation` when using multithreading. - Otherwise the stack is in disarray. - ## Reusing sqlalchemy engine of databasez For integration in other libraries databasez has also the AsyncEngine exposed via the `engine` property. From b9f786a557702a418876c5bd5316a600bfde4fc2 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 5 Sep 2024 14:20:58 +0200 Subject: [PATCH 3/4] fix flipped default --- databasez/core/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databasez/core/database.py b/databasez/core/database.py index d2e377a..c9c6c99 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -247,7 +247,7 @@ def __init__( if force_rollback is None: force_rollback = False if full_isolation is None: - full_isolation = True + full_isolation = False if poll_interval is None: poll_interval = DATABASEZ_POLL_INTERVAL self.poll_interval = poll_interval From 1823072d50ce29be9a3d16a95f72fa47ed45ea93 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 5 Sep 2024 14:25:10 +0200 Subject: [PATCH 4/4] test decorator harder --- tests/test_transactions.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_transactions.py b/tests/test_transactions.py index d2b6664..d7667d6 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -396,12 +396,15 @@ async def test_transaction_rollback_low_level(database_url): assert len(results) == 0 +@pytest.mark.parametrize( + "full_isolation", [True, False], ids=["full_isolation", "no_full_isolation"] +) @pytest.mark.asyncio -async def test_transaction_decorator(database_url): +async def test_transaction_decorator(database_url, full_isolation): """ Ensure that @database.transaction() is supported. """ - database = Database(database_url, force_rollback=True, full_isolation=True) + database = Database(database_url, force_rollback=True, full_isolation=full_isolation) @database.transaction() async def insert_data(raise_exception): @@ -424,6 +427,11 @@ async def insert_data(raise_exception): results = await database.fetch_all(query=query) assert len(results) == 1 + async with database: + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 + # highly default isolation level specific