Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix transactions in multithreading contexts #53

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.10.1"
__version__ = "0.10.2"

__all__ = ["Database", "DatabaseURL"]
4 changes: 2 additions & 2 deletions databasez/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .connection import Connection
from .database import Database, init
from .databaseurl import DatabaseURL
from .transaction import ACTIVE_TRANSACTIONS, Transaction
from .transaction import Transaction

__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction", "ACTIVE_TRANSACTIONS"]
__all__ = ["Connection", "Database", "init", "DatabaseURL", "Transaction"]
58 changes: 54 additions & 4 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import asyncio
import typing
import weakref
from functools import partial
from threading import Event, Lock, Thread, current_thread
from types import TracebackType

from sqlalchemy import text

from databasez import interfaces
from databasez.utils import multiloop_protector
from databasez.utils import _arun_with_timeout, arun_coroutine_threadsafe, multiloop_protector

from .transaction import Transaction

Expand Down Expand Up @@ -59,7 +60,50 @@ def _init_thread(
database._global_connection._isolation_thread = None # type: ignore


class AsyncHelperConnection:
def __init__(
self,
connection: Connection,
fn: typing.Callable,
args: typing.Any,
kwargs: typing.Any,
timeout: typing.Optional[float],
) -> None:
self.connection = connection
self.fn = partial(fn, self.connection, *args, **kwargs)
self.timeout = timeout
self.ctm = None

async def call(self) -> typing.Any:
# is automatically awaited
result = await _arun_with_timeout(self.fn(), self.timeout)
return result

async def acall(self) -> typing.Any:
return await arun_coroutine_threadsafe(
self.call(), self.connection._loop, self.connection.poll_interval
)

def __await__(self) -> typing.Any:
return self.acall().__await__()

async def __aiter__(self) -> typing.Any:
result = await self.acall()
try:
while True:
yield await arun_coroutine_threadsafe(
_arun_with_timeout(result.__anext__(), self.timeout),
self.connection._loop,
self.connection.poll_interval,
)
except StopAsyncIteration:
pass


class Connection:
# async helper
async_helper: typing.Type[AsyncHelperConnection] = AsyncHelperConnection

def __init__(
self, database: Database, force_rollback: bool = False, full_isolation: bool = False
) -> None:
Expand All @@ -86,11 +130,18 @@ def __init__(
self._connection.owner = self
self._connection_counter = 0

self._transaction_stack: typing.List[Transaction] = []
# for keeping weak references to transactions active
self._transaction_stack: typing.List[
typing.Tuple[Transaction, interfaces.TransactionBackend]
] = []

self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

@multiloop_protector(True)
def _get_connection_backend(self) -> interfaces.ConnectionBackend:
return self._connection

@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def _aenter(self) -> None:
async with self._connection_lock:
Expand All @@ -111,8 +162,7 @@ async def _aenter(self) -> None:
self.connection_transaction = self.transaction(
force_rollback=self._force_rollback
)
# make re-entrant, we have already the connection lock
await self.connection_transaction.start(True)
await self.connection_transaction.start()
except BaseException as e:
self._connection_counter -= 1
raise e
Expand Down
54 changes: 53 additions & 1 deletion databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from types import TracebackType

from databasez import interfaces
from databasez.utils import DATABASEZ_POLL_INTERVAL, arun_coroutine_threadsafe, multiloop_protector
from databasez.utils import (
DATABASEZ_POLL_INTERVAL,
_arun_with_timeout,
arun_coroutine_threadsafe,
multiloop_protector,
)

from .connection import Connection
from .databaseurl import DatabaseURL
Expand Down Expand Up @@ -118,6 +123,51 @@ def __delete__(self, obj: Database) -> None:
obj._force_rollback.set(None)


class AsyncHelperDatabase:
def __init__(
self,
database: Database,
fn: typing.Callable,
args: typing.Any,
kwargs: typing.Any,
timeout: typing.Optional[float],
) -> None:
self.database = database
self.fn = fn
self.args = args
self.kwargs = kwargs
self.timeout = timeout
self.ctm = None

async def call(self) -> typing.Any:
async with self.database as database:
return await _arun_with_timeout(
self.fn(database, *self.args, **self.kwargs), self.timeout
)

def __await__(self) -> typing.Any:
return self.call().__await__()

async def __aenter__(self) -> typing.Any:
database = await self.database.__aenter__()
self.ctm = await _arun_with_timeout(
self.fn(database, *self.args, **self.kwargs), timeout=self.timeout
)
return await self.ctm.__aenter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
assert self.ctm is not None
try:
await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None)
finally:
await self.database.__aexit__()


class Database:
"""
An abstraction on the top of the EncodeORM databases.Database object.
Expand Down Expand Up @@ -156,6 +206,8 @@ class Database:
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
# async helper
async_helper: typing.Type[AsyncHelperDatabase] = AsyncHelperDatabase

def __init__(
self,
Expand Down
Loading