diff --git a/databasez/__init__.py b/databasez/__init__.py
index d69f2a5..c084037 100644
--- a/databasez/__init__.py
+++ b/databasez/__init__.py
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL
-__version__ = "0.9.7"
+__version__ = "0.10.0"
__all__ = ["Database", "DatabaseURL"]
diff --git a/databasez/core/connection.py b/databasez/core/connection.py
index 94b3962..1ed5904 100644
--- a/databasez/core/connection.py
+++ b/databasez/core/connection.py
@@ -1,8 +1,11 @@
from __future__ import annotations
import asyncio
+import sys
import typing
import weakref
+from contextvars import copy_context
+from threading import Event, RLock, Thread, current_thread
from types import TracebackType
from sqlalchemy import text
@@ -16,30 +19,72 @@
from sqlalchemy import MetaData
from sqlalchemy.sql import ClauseElement
+ from databasez.types import BatchCallable, BatchCallableResult
+
from .database import Database
+async def _startup(database: Database, is_initialized: Event) -> None:
+ await database.connect()
+ _global_connection = typing.cast(Connection, database._global_connection)
+ await _global_connection._aenter()
+ if sys.version_info < (3, 10):
+ # for old python versions <3.10 the locks must be created in the same event loop
+ _global_connection._query_lock = asyncio.Lock()
+ _global_connection._connection_lock = asyncio.Lock()
+ _global_connection._transaction_lock = asyncio.Lock()
+ is_initialized.set()
+
+
+def _init_thread(database: Database, is_initialized: Event) -> None:
+ loop = asyncio.new_event_loop()
+ task = loop.create_task(_startup(database, is_initialized))
+ try:
+ loop.run_forever()
+ except RuntimeError:
+ pass
+ try:
+ task.result()
+ finally:
+ try:
+ loop.run_until_complete(database.disconnect())
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ del task
+ loop.close()
+ database._loop = None
+
+
class Connection:
def __init__(
- self, database: Database, backend: interfaces.DatabaseBackend, force_rollback: bool = False
+ self, database: Database, force_rollback: bool = False, full_isolation: bool = False
) -> None:
- self._database = database
- self._backend = backend
-
+ self._orig_database = self._database = database
+ self._full_isolation = full_isolation
+ self._connection_thread_lock: typing.Optional[RLock] = None
+ self._isolation_thread: typing.Optional[Thread] = None
+ if self._full_isolation:
+ self._database = database.__class__(
+ database, force_rollback=force_rollback, full_isolation=False
+ )
+ self._database._call_hooks = False
+ self._database._global_connection = self
+ self._connection_thread_lock = RLock()
+ # the asyncio locks are overwritten in python versions < 3.10 when using full_isolation
+ self._query_lock = asyncio.Lock()
self._connection_lock = asyncio.Lock()
+ self._transaction_lock = asyncio.Lock()
self._connection = self._backend.connection()
self._connection.owner = self
self._connection_counter = 0
- self._transaction_lock = asyncio.Lock()
self._transaction_stack: typing.List[Transaction] = []
- self._query_lock = asyncio.Lock()
self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None
- @multiloop_protector(False)
- async def __aenter__(self) -> Connection:
+ @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
+ async def _aenter(self) -> None:
async with self._connection_lock:
self._connection_counter += 1
try:
@@ -63,38 +108,98 @@ async def __aenter__(self) -> Connection:
except BaseException as e:
self._connection_counter -= 1
raise e
+
+ async def __aenter__(self) -> Connection:
+ initialized = False
+ if self._full_isolation:
+ assert self._connection_thread_lock is not None
+ with self._connection_thread_lock:
+ if self._isolation_thread is None:
+ initialized = True
+ is_initialized = Event()
+ ctx = copy_context()
+ self._isolation_thread = thread = Thread(
+ target=ctx.run,
+ args=[
+ _init_thread,
+ self._database,
+ is_initialized,
+ ],
+ daemon=True,
+ )
+ thread.start()
+ is_initialized.wait()
+ if not thread.is_alive():
+ self._isolation_thread = None
+ thread.join()
+ if not initialized:
+ await self._aenter()
return self
- @multiloop_protector(False)
- async def __aexit__(
- self,
- exc_type: typing.Optional[typing.Type[BaseException]] = None,
- exc_value: typing.Optional[BaseException] = None,
- traceback: typing.Optional[TracebackType] = None,
- ) -> None:
+ async def _aexit_raw(self) -> bool:
+ closing = False
async with self._connection_lock:
assert self._connection is not None
self._connection_counter -= 1
if self._connection_counter == 0:
+ closing = True
try:
if self.connection_transaction:
# __aexit__ needs the connection_transaction parameter
- await self.connection_transaction.__aexit__(exc_type, exc_value, traceback)
+ await self.connection_transaction.__aexit__()
# untie, for allowing gc
self.connection_transaction = None
finally:
await self._connection.release()
self._database._connection = None
+ return closing
+
+ async def _aexit(self) -> typing.Optional[Thread]:
+ if self._full_isolation:
+ assert self._connection_thread_lock is not None
+ with self._connection_thread_lock:
+ if await self._aexit_raw():
+ loop = self._database._loop
+ thread = self._isolation_thread
+ if loop is not None:
+ loop.stop()
+ else:
+ self._isolation_thread = None
+ return thread
+ else:
+ await self._aexit_raw()
+ return None
+
+ @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
+ async def __aexit__(
+ self,
+ exc_type: typing.Optional[typing.Type[BaseException]] = None,
+ exc_value: typing.Optional[BaseException] = None,
+ traceback: typing.Optional[TracebackType] = None,
+ ) -> None:
+ thread = None
+ try:
+ thread = await self._aexit()
+ finally:
+ if thread is not None and thread is not current_thread():
+ thread.join()
@property
def _loop(self) -> typing.Any:
return self._database._loop
+ @property
+ def _backend(self) -> interfaces.DatabaseBackend:
+ return self._database.backend
+
@multiloop_protector(False)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.List[interfaces.Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
@@ -106,6 +211,9 @@ async def fetch_one(
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
pos: int = 0,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Optional[interfaces.Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
@@ -118,6 +226,9 @@ async def fetch_val(
values: typing.Optional[dict] = None,
column: typing.Any = 0,
pos: int = 0,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Any:
built_query = self._build_query(query, values)
async with self._query_lock:
@@ -128,6 +239,9 @@ async def execute(
self,
query: typing.Union[ClauseElement, str],
values: typing.Any = None,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Union[interfaces.Record, int]:
if isinstance(query, str):
built_query = self._build_query(query, values)
@@ -139,7 +253,12 @@ async def execute(
@multiloop_protector(False)
async def execute_many(
- self, query: typing.Union[ClauseElement, str], values: typing.Any = None
+ self,
+ query: typing.Union[ClauseElement, str],
+ values: typing.Any = None,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
if isinstance(query, str):
built_query = self._build_query(query, None)
@@ -149,49 +268,96 @@ async def execute_many(
async with self._query_lock:
return await self._connection.execute_many(query, values)
- @multiloop_protector(False)
+ @multiloop_protector(False, passthrough_timeout=True)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
- batch_size: typing.Optional[int] = None,
- ) -> typing.AsyncGenerator[typing.Any, None]:
+ chunk_size: typing.Optional[int] = None,
+ timeout: typing.Optional[float] = None,
+ ) -> typing.AsyncGenerator[interfaces.Record, None]:
built_query = self._build_query(query, values)
+ if timeout is None or timeout <= 0:
+ # anext is available in python 3.10
+
+ async def next_fn(inp: typing.Any) -> interfaces.Record:
+ return await aiterator.__anext__()
+ else:
+
+ async def next_fn(inp: typing.Any) -> interfaces.Record:
+ return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout)
+
async with self._query_lock:
- async for record in self._connection.iterate(built_query, batch_size):
- yield record
+ aiterator = self._connection.iterate(built_query, chunk_size).__aiter__()
+ try:
+ while True:
+ yield await next_fn(aiterator)
+ except StopAsyncIteration:
+ pass
- @multiloop_protector(False)
+ @multiloop_protector(False, passthrough_timeout=True)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
batch_size: typing.Optional[int] = None,
- ) -> typing.AsyncGenerator[typing.Any, None]:
+ batch_wrapper: BatchCallable = tuple,
+ timeout: typing.Optional[float] = None,
+ ) -> typing.AsyncGenerator[BatchCallableResult, None]:
built_query = self._build_query(query, values)
+ if timeout is None or timeout <= 0:
+ # anext is available in python 3.10
+
+ async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]:
+ return await aiterator.__anext__()
+ else:
+
+ async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]:
+ return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout)
+
async with self._query_lock:
- async for records in self._connection.batched_iterate(built_query, batch_size):
- yield records
+ aiterator = self._connection.batched_iterate(built_query, batch_size).__aiter__()
+ try:
+ while True:
+ yield batch_wrapper(await next_fn(aiterator))
+ except StopAsyncIteration:
+ pass
@multiloop_protector(False)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
*args: typing.Any,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
**kwargs: typing.Any,
) -> typing.Any:
async with self._query_lock:
return await self._connection.run_sync(fn, *args, **kwargs)
@multiloop_protector(False)
- async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
+ async def create_all(
+ self,
+ meta: MetaData,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
+ **kwargs: typing.Any,
+ ) -> None:
await self.run_sync(meta.create_all, **kwargs)
@multiloop_protector(False)
- async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
+ async def drop_all(
+ self,
+ meta: MetaData,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
+ **kwargs: typing.Any,
+ ) -> None:
await self.run_sync(meta.drop_all, **kwargs)
- @multiloop_protector(False)
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(weakref.ref(self), force_rollback, **kwargs)
@@ -202,7 +368,12 @@ def async_connection(self) -> typing.Any:
return self._connection.async_connection
@multiloop_protector(False)
- async def get_raw_connection(self) -> typing.Any:
+ async def get_raw_connection(
+ self,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
+ ) -> typing.Any:
"""The real raw connection (driver)."""
return await self.async_connection.get_raw_connection()
diff --git a/databasez/core/database.py b/databasez/core/database.py
index 70da951..4d5c40b 100644
--- a/databasez/core/database.py
+++ b/databasez/core/database.py
@@ -150,6 +150,7 @@ class Database:
options: typing.Any
is_connected: bool = False
_call_hooks: bool = True
+ _full_isolation: bool = False
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
@@ -160,6 +161,7 @@ def __init__(
*,
force_rollback: typing.Union[bool, None] = None,
config: typing.Optional["DictAny"] = None,
+ full_isolation: typing.Union[bool, None] = None,
**options: typing.Any,
):
init()
@@ -172,6 +174,8 @@ def __init__(
self._call_hooks = url._call_hooks
if force_rollback is None:
force_rollback = bool(url.force_rollback)
+ if full_isolation is None:
+ full_isolation = bool(url._full_isolation)
else:
url = DatabaseURL(url)
if config and "connection" in config:
@@ -184,6 +188,9 @@ def __init__(
)
if force_rollback is None:
force_rollback = False
+ if full_isolation is None:
+ full_isolation = False
+ self._full_isolation = full_isolation
self._force_rollback = ForceRollback(force_rollback)
self.backend.owner = self
self._connection_map = weakref.WeakKeyDictionary()
@@ -264,10 +271,15 @@ async def connect(self) -> bool:
if self._loop is not None and loop != self._loop:
# copy when not in map
if loop not in self._databases_map:
+ assert (
+ self._global_connection is not None
+ ), "global connection should have been set"
+ # correctly initialize force_rollback with parent value
+ database = self.__class__(
+ self, force_rollback=bool(self.force_rollback), full_isolation=False
+ )
# prevent side effects of connect_hook
- database = self.__copy__()
database._call_hooks = False
- assert self._global_connection
database._global_connection = await self._global_connection.__aenter__()
self._databases_map[loop] = database
# forward call
@@ -289,7 +301,8 @@ async def connect(self) -> bool:
self.is_connected = True
if self._global_connection is None:
- self._global_connection = Connection(self, self.backend, force_rollback=True)
+ connection = Connection(self, force_rollback=True, full_isolation=self._full_isolation)
+ self._global_connection = connection
return True
async def disconnect_hook(self) -> None:
@@ -315,9 +328,13 @@ async def disconnect(
loop = asyncio.get_running_loop()
del parent_database._databases_map[loop]
if force:
- for sub_database in self._databases_map.values():
- await sub_database.disconnect(True)
- self._databases_map = {}
+ if self._databases_map:
+ assert not self._databases_map, "sub databases still active, force terminate them"
+ for sub_database in self._databases_map.values():
+ asyncio.run_coroutine_threadsafe(
+ sub_database.disconnect(True), sub_database._loop
+ )
+ self._databases_map = {}
assert not self._databases_map, "sub databases still active"
try:
@@ -353,111 +370,135 @@ async def __aexit__(
) -> None:
await self.disconnect()
- @multiloop_protector(False)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
+ timeout: typing.Optional[float] = None,
) -> typing.List[interfaces.Record]:
async with self.connection() as connection:
- return await connection.fetch_all(query, values)
+ return await connection.fetch_all(query, values, timeout=timeout)
- @multiloop_protector(False)
async def fetch_one(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
pos: int = 0,
+ timeout: typing.Optional[float] = None,
) -> typing.Optional[interfaces.Record]:
async with self.connection() as connection:
- return await connection.fetch_one(query, values, pos=pos)
- assert connection._connection_counter == 1
+ return await connection.fetch_one(query, values, pos=pos, timeout=timeout)
- @multiloop_protector(False)
async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
column: typing.Any = 0,
pos: int = 0,
+ timeout: typing.Optional[float] = None,
) -> typing.Any:
async with self.connection() as connection:
- return await connection.fetch_val(query, values, column=column, pos=pos)
+ return await connection.fetch_val(
+ query,
+ values,
+ column=column,
+ pos=pos,
+ timeout=timeout,
+ )
- @multiloop_protector(False)
async def execute(
self,
query: typing.Union[ClauseElement, str],
values: typing.Any = None,
+ timeout: typing.Optional[float] = None,
) -> typing.Union[interfaces.Record, int]:
async with self.connection() as connection:
- return await connection.execute(query, values)
+ return await connection.execute(query, values, timeout=timeout)
- @multiloop_protector(False)
async def execute_many(
- self, query: typing.Union[ClauseElement, str], values: typing.Any = None
+ self,
+ query: typing.Union[ClauseElement, str],
+ values: typing.Any = None,
+ timeout: typing.Optional[float] = None,
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
async with self.connection() as connection:
- return await connection.execute_many(query, values)
+ return await connection.execute_many(query, values, timeout=timeout)
- @multiloop_protector(False)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
chunk_size: typing.Optional[int] = None,
+ timeout: typing.Optional[float] = None,
) -> typing.AsyncGenerator[interfaces.Record, None]:
async with self.connection() as connection:
- async for record in connection.iterate(query, values, chunk_size):
+ async for record in connection.iterate(query, values, chunk_size, timeout=timeout):
yield record
- @multiloop_protector(False)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
batch_size: typing.Optional[int] = None,
- batch_wrapper: typing.Union[BatchCallable] = tuple,
+ batch_wrapper: BatchCallable = tuple,
+ timeout: typing.Optional[float] = None,
) -> typing.AsyncGenerator[BatchCallableResult, None]:
async with self.connection() as connection:
- async for records in connection.batched_iterate(query, values, batch_size):
- yield batch_wrapper(records)
+ async for batch in typing.cast(
+ typing.AsyncGenerator["BatchCallableResult", None],
+ connection.batched_iterate(
+ query,
+ values,
+ batch_wrapper=batch_wrapper,
+ batch_size=batch_size,
+ timeout=timeout,
+ ),
+ ):
+ yield batch
- @multiloop_protector(True)
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
- @multiloop_protector(False)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
*args: typing.Any,
+ timeout: typing.Optional[float] = None,
**kwargs: typing.Any,
) -> typing.Any:
async with self.connection() as connection:
- return await connection.run_sync(fn, *args, **kwargs)
+ return await connection.run_sync(fn, *args, **kwargs, timeout=timeout)
- @multiloop_protector(False)
- async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
+ async def create_all(
+ self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any
+ ) -> None:
async with self.connection() as connection:
- await connection.create_all(meta, **kwargs)
+ await connection.create_all(meta, **kwargs, timeout=timeout)
- @multiloop_protector(False)
- async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
+ async def drop_all(
+ self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any
+ ) -> None:
async with self.connection() as connection:
- await connection.drop_all(meta, **kwargs)
+ await connection.drop_all(meta, **kwargs, timeout=timeout)
@multiloop_protector(False)
- def connection(self) -> Connection:
+ def _non_global_connection(
+ self,
+ timeout: typing.Optional[
+ float
+ ] = None, # stub for type checker, multiloop_protector handles timeout
+ ) -> Connection:
+ if self._connection is None:
+ _connection = self._connection = Connection(self)
+ return _connection
+ return self._connection
+
+ def connection(self, timeout: typing.Optional[float] = None) -> Connection:
if not self.is_connected:
raise RuntimeError("Database is not connected")
if self.force_rollback:
return typing.cast(Connection, self._global_connection)
-
- if self._connection is None:
- _connection = self._connection = Connection(self, self.backend)
- return _connection
- return self._connection
+ return self._non_global_connection(timeout=timeout)
@property
@multiloop_protector(True)
diff --git a/databasez/testclient.py b/databasez/testclient.py
index c00b262..fc7959d 100644
--- a/databasez/testclient.py
+++ b/databasez/testclient.py
@@ -34,6 +34,7 @@ class DatabaseTestClient(Database):
# is used for copying Database and DatabaseTestClientand providing an early url
test_db_url: str
# hooks for overwriting defaults of args with None
+ testclient_default_full_isolation: bool = True
testclient_default_force_rollback: bool = False
testclient_default_lazy_setup: bool = False
# customization hooks
@@ -47,6 +48,7 @@ def __init__(
*,
force_rollback: typing.Union[bool, None] = None,
use_existing: typing.Union[bool, None] = None,
+ full_isolation: typing.Union[bool, None] = None,
drop_database: typing.Union[bool, None] = None,
lazy_setup: typing.Union[bool, None] = None,
test_prefix: typing.Union[str, None] = None,
@@ -56,6 +58,8 @@ def __init__(
use_existing = self.testclient_default_use_existing
if drop_database is None:
drop_database = self.testclient_default_drop_database
+ if full_isolation is None:
+ full_isolation = self.testclient_default_full_isolation
if test_prefix is None:
test_prefix = self.testclient_default_test_prefix
self._setup_executed_init = False
@@ -144,7 +148,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, "postgres", "template1", "template0", None):
url = url.replace(database=db)
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
@@ -157,7 +161,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME = '%s'" % database
)
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
elif dialect_name == "sqlite":
@@ -169,7 +173,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
return True
else:
text = "SELECT 1"
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
@@ -201,9 +205,11 @@ async def create_database(
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
- db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
+ db_client = Database(
+ url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
+ )
else:
- db_client = Database(url, force_rollback=False)
+ db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "postgresql":
if not template:
@@ -255,9 +261,11 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
- db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
+ db_client = Database(
+ url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
+ )
else:
- db_client = Database(url, force_rollback=False)
+ db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "sqlite" and database and database != ":memory:":
try:
diff --git a/databasez/utils.py b/databasez/utils.py
index 07262f5..b2c7cc4 100644
--- a/databasez/utils.py
+++ b/databasez/utils.py
@@ -140,28 +140,50 @@ def join(self, timeout: typing.Union[float, int, None] = None) -> None:
MultiloopProtectorCallable = typing.TypeVar("MultiloopProtectorCallable", bound=typing.Callable)
+def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any:
+ if timeout is not None and timeout > 0 and inspect.isawaitable(inp):
+ inp = asyncio.wait_for(inp, timeout=timeout)
+ return inp
+
+
+async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any:
+ if timeout is not None and timeout > 0 and inspect.isawaitable(inp):
+ inp = await asyncio.wait_for(inp, timeout=timeout)
+ elif inspect.isawaitable(inp):
+ return await inp
+ return inp
+
+
class AsyncHelperDatabase:
def __init__(
self,
database: typing.Any,
fn: typing.Callable,
- *args: typing.Any,
- **kwargs: typing.Any,
+ args: typing.Any,
+ kwargs: typing.Any,
+ timeout: typing.Optional[float],
) -> None:
- self.database = database.__copy__()
- self.fn = partial(fn, self.database, *args, **kwargs)
+ self.database = database
+ self.fn = fn
+ self.args = args
+ self.kwargs = kwargs
+ self.timeout = timeout
self.ctm = None
async def call(self) -> typing.Any:
- async with self.database:
- return await self.fn()
+ async with self.database as database:
+ return await _arun_with_timeout(
+ self.fn(database, *self.args, **self.kwargs), self.timeout
+ )
def __await__(self) -> typing.Any:
return self.call().__await__()
async def __aenter__(self) -> typing.Any:
- await self.database.__aenter__()
- self.ctm = self.fn()
+ database = await self.database.__aenter__()
+ self.ctm = await _arun_with_timeout(
+ self.fn(database, *self.args, **self.kwargs), timeout=self.timeout
+ )
return await self.ctm.__aenter__()
async def __aexit__(
@@ -172,7 +194,7 @@ async def __aexit__(
) -> None:
assert self.ctm is not None
try:
- await self.ctm.__aexit__(exc_type, exc_value, traceback)
+ await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None)
finally:
await self.database.__aexit__()
@@ -182,18 +204,19 @@ def __init__(
self,
connection: typing.Any,
fn: typing.Callable,
- *args: typing.Any,
- **kwargs: typing.Any,
+ args: typing.Any,
+ kwargs: typing.Any,
+ timeout: typing.Optional[float],
) -> None:
self.connection = connection
self.fn = partial(fn, self.connection, *args, **kwargs)
+ self.timeout = timeout
self.ctm = None
async def call(self) -> typing.Any:
async with self.connection:
- result = self.fn()
- if inspect.isawaitable(result):
- result = await result
+ # is automatically awaited
+ result = await _arun_with_timeout(self.fn(), self.timeout)
return result
async def acall(self) -> typing.Any:
@@ -231,7 +254,7 @@ async def __aexit__(
def multiloop_protector(
- fail_with_different_loop: bool, inject_parent: bool = False
+ fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False
) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]:
"""For multiple threads or other reasons why the loop changes"""
@@ -239,7 +262,14 @@ def multiloop_protector(
# needs _loop attribute to check against
def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable:
@wraps(fn)
- def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
+ def wrapper(
+ self: typing.Any,
+ *args: typing.Any,
+ **kwargs: typing.Any,
+ ) -> typing.Any:
+ timeout: typing.Optional[float] = None
+ if not passthrough_timeout and "timeout" in kwargs:
+ timeout = kwargs.pop("timeout")
if inject_parent:
assert "parent_database" not in kwargs, '"parent_database" is a reserved keyword'
try:
@@ -262,8 +292,8 @@ def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing
if hasattr(self, "_databases_map")
else AsyncHelperConnection
)
- return helper(self, fn, *args, **kwargs)
- return fn(self, *args, **kwargs)
+ return helper(self, fn, args, kwargs, timeout=timeout)
+ return _run_with_timeout(fn(self, *args, **kwargs), timeout=timeout)
return typing.cast(MultiloopProtectorCallable, wrapper)
diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md
index ee8300d..19f2b2f 100644
--- a/docs/connections-and-transactions.md
+++ b/docs/connections-and-transactions.md
@@ -24,6 +24,13 @@ from databasez import Database
Default: `None`
+* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
+  in a separate thread. This makes it possible to use blocking operations such as locks together with force_rollback.
+  This parameter has no effect without force_rollback and causes a slightly slower setup (a lock is initialized).
+  It is required for edgy or other frameworks which use threads together with the force_rollback parameter in tests.
+
+ Default: `None`
+
* **config** - A python like dictionary as alternative to the `url` that contains the information
to connect to the database.
@@ -35,6 +42,8 @@ to connect to the database.
Be careful when setting up the `url` or `config`. You can use one or the other but not both
at the same time.
+!!! Warning
+ `full_isolation` is not mature and shouldn't be used in production code.
**Attributes***
diff --git a/docs/queries.md b/docs/queries.md
index c6af888..1594e65 100644
--- a/docs/queries.md
+++ b/docs/queries.md
@@ -49,7 +49,7 @@ execute it with databasez.
## Queries
Since you can use [SQLAlchemy core](https://docs.sqlalchemy.org/en/20/core/), that also means you
-can also use the queries. Check out the [official tutorial](https://docs.sqlalchemy.org/en/14/core/tutorial.html).
+can also use the queries. Check out the [official tutorial](https://docs.sqlalchemy.org/en/20/tutorial/).
```python
{!> ../docs_src/queries/queries.py !}
diff --git a/docs/release-notes.md b/docs/release-notes.md
index c099b24..10315a2 100644
--- a/docs/release-notes.md
+++ b/docs/release-notes.md
@@ -1,6 +1,22 @@
# Release Notes
+## 0.10.0
+
+
+### Added
+
+- `full_isolation` parameter. Isolate the force_rollback Connection in a thread.
+- Timeouts for operations.
+
+### Fixed
+
+- The `batched_iterate` interface of `Connection` differed from that of `Database`.
+- The `iterate` interface of `Connection` differed from that of `Database`.
+- Hooks were called on automatically created Database objects.
+- Improved multithreading safety.
+
+
## 0.9.7
### Added
diff --git a/docs/test-client.md b/docs/test-client.md
index df2b4db..3771b11 100644
--- a/docs/test-client.md
+++ b/docs/test-client.md
@@ -48,6 +48,14 @@ from databasez.testclient import DatabaseTestClient
Default: `None`, copy default or `testclient_default_force_rollback` (defaults to `False`)
+* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
+  in a separate thread. This makes it possible to use blocking operations such as locks together with force_rollback.
+  This parameter has no effect without force_rollback and causes a slightly slower setup (a lock is initialized).
+  It is required for edgy or other frameworks which use threads together with the force_rollback parameter in tests.
+  For the DatabaseTestClient it is enabled by default.
+
+ Default: `None`, copy default or `testclient_default_full_isolation` (defaults to `True`)
+
* **lazy_setup** - This sets up the db first up on connect not in init.
Default: `None`, True if copying a database or `testclient_default_lazy_setup` (defaults to `False`)
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 0000000..c89d2c3
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,246 @@
+import asyncio
+import contextvars
+import functools
+import os
+from concurrent.futures import Future
+from threading import Thread
+
+import pyodbc
+import pytest
+
+from databasez import Database, DatabaseURL
+from tests.shared_db import (
+ database_client,
+ notes,
+ stop_database_client,
+)
+
+assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set."
+
+DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
+
+if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())):
+ DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS))
+
+
+try:
+ to_thread = asyncio.to_thread
+except AttributeError:
+ # for py <= 3.8
+ async def to_thread(func, /, *args, **kwargs):
+ loop = asyncio.get_running_loop()
+ ctx = contextvars.copy_context()
+ func_call = functools.partial(ctx.run, func, *args, **kwargs)
+ return await loop.run_in_executor(None, func_call)
+
+
+@pytest.fixture(params=DATABASE_URLS)
+def database_url(request):
+    """Yield a test database client (the fixture name is kept for compatibility)"""
+ # yield test Databases
+ loop = asyncio.new_event_loop()
+ database = loop.run_until_complete(database_client(request.param))
+ yield database
+ loop.run_until_complete(stop_database_client(database))
+
+
+def _startswith(tested, params):
+ for param in params:
+ if tested.startswith(param):
+ return True
+ return False
+
+
+@pytest.mark.asyncio
+async def test_concurrent_access_on_single_connection(database_url):
+ database_url = DatabaseURL(str(database_url.url))
+ if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
+ pytest.skip("Test requires sleep function")
+ async with Database(database_url, force_rollback=True, full_isolation=False) as database:
+
+ async def db_lookup():
+ if database_url.dialect.startswith("postgres"):
+ await database.fetch_one("SELECT pg_sleep(0.3)")
+ elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
+ "mariadb"
+ ):
+ await database.fetch_one("SELECT SLEEP(0.3)")
+ elif database_url.dialect.startswith("mssql"):
+ await database.execute("WAITFOR DELAY '00:00:00.300'")
+
+ await asyncio.gather(db_lookup(), db_lookup(), db_lookup())
+
+
+def _future_helper(awaitable, future):
+ try:
+ future.set_result(asyncio.run(awaitable))
+ except BaseException as exc:
+ future.set_exception(exc)
+
+
+@pytest.mark.parametrize(
+ "join_type,full_isolation",
+ [
+ ("to_thread", False),
+ ("to_thread", True),
+ ("thread_join_with_context", True),
+ ("thread_join_without_context", True),
+ ],
+)
+@pytest.mark.parametrize("force_rollback", [True, False])
+@pytest.mark.asyncio
+async def test_multi_thread_db(database_url, force_rollback, join_type, full_isolation):
+ database_url = DatabaseURL(str(database_url.url))
+ async with Database(
+ database_url, force_rollback=force_rollback, full_isolation=full_isolation
+ ) as database:
+
+ async def db_lookup(in_thread):
+ async with database.connection() as conn:
+ assert bool(conn._database.force_rollback) == force_rollback
+ if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
+ return
+ if database_url.dialect.startswith("postgres"):
+ await database.fetch_one("SELECT pg_sleep(0.3)")
+ elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
+ "mariadb"
+ ):
+ await database.fetch_one("SELECT SLEEP(0.3)")
+ elif database_url.dialect.startswith("mssql"):
+ await database.execute("WAITFOR DELAY '00:00:00.300'")
+
+ async def wrap_in_thread():
+ if join_type.startswith("thread_join"):
+ future = Future()
+ args = [_future_helper, asyncio.wait_for(db_lookup(True), 3), future]
+ if join_type == "thread_join_with_context":
+ ctx = contextvars.copy_context()
+ args.insert(0, ctx.run)
+ thread = Thread(target=args[0], args=args[1:])
+ thread.start()
+ future.result(4)
+ else:
+ await to_thread(asyncio.run, asyncio.wait_for(db_lookup(True), 3))
+
+ await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread())
+
+
+@pytest.mark.parametrize(
+ "join_type,full_isolation",
+ [
+ ("to_thread", False),
+ ("to_thread", True),
+ ("thread_join_with_context", True),
+ ("thread_join_without_context", True),
+ ],
+)
+@pytest.mark.parametrize("force_rollback", [True, False])
+@pytest.mark.asyncio
+async def test_multi_thread_db_contextmanager(
+ database_url, force_rollback, join_type, full_isolation
+):
+ async with Database(
+ database_url, force_rollback=force_rollback, full_isolation=full_isolation
+ ) as database:
+ query = notes.insert().values(text="examplecontext", completed=True)
+ await database.execute(query, timeout=10)
+ database._non_copied_attribute = True
+
+ async def db_connect(depth=3):
+ # many parallel and nested threads
+ async with database as new_database:
+ assert not hasattr(new_database, "_non_copied_attribute")
+ query = notes.select()
+ result = await database.fetch_one(query)
+ assert result.text == "examplecontext"
+ assert result.completed is True
+ # test delegate to sub database
+ assert database.engine is new_database.engine
+ # also this shouldn't fail because redirected
+ old_refcount = new_database.ref_counter
+ await database.connect()
+ assert new_database.ref_counter == old_refcount + 1
+ await database.disconnect()
+ ops = []
+ while depth >= 0:
+ depth -= 1
+ ops.append(to_thread(asyncio.run, db_connect(depth=depth)))
+ await asyncio.gather(*ops)
+ assert new_database.ref_counter == 0
+
+ if join_type.startswith("thread_join"):
+ future = Future()
+ args = [_future_helper, asyncio.wait_for(db_connect(), 3), future]
+ if join_type == "thread_join_with_context":
+ ctx = contextvars.copy_context()
+ args.insert(0, ctx.run)
+ thread = Thread(target=args[0], args=args[1:])
+ thread.start()
+ future.result(4)
+ else:
+ await to_thread(asyncio.run, asyncio.wait_for(db_connect(), 3))
+ assert database.ref_counter == 0
+ if force_rollback:
+ async with database:
+ query = notes.select()
+ result = await database.fetch_one(query)
+ assert result is None
+
+
+@pytest.mark.asyncio
+async def test_multi_thread_db_connect(database_url):
+ async with Database(database_url, force_rollback=True) as database:
+
+ async def db_connect():
+ await database.connect()
+ await database.fetch_one("SELECT 1")
+ await database.disconnect()
+
+ await to_thread(asyncio.run, db_connect())
+
+
+@pytest.mark.asyncio
+async def test_multi_thread_db_fails(database_url):
+ async with Database(database_url, force_rollback=True) as database:
+
+ async def db_connect():
+ # not in same loop
+ database.disconnect()
+
+ with pytest.raises(RuntimeError):
+ await to_thread(asyncio.run, db_connect())
+
+
+@pytest.mark.asyncio
+async def test_global_connection_is_initialized_lazily(database_url):
+ """
+    Ensure that the global connection is initialized at the latest possible time
+    so its _query_lock will belong to the same event loop that async_adapter has
+    initialized.
+
+ See https://github.com/dymmond/databasez/issues/157 for more context.
+ """
+
+ database_url = DatabaseURL(database_url.url)
+ if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
+ pytest.skip("Test requires sleep function")
+
+ database = Database(database_url, force_rollback=True)
+
+ async def run_database_queries():
+ async with database:
+
+ async def db_lookup():
+ if database_url.dialect.startswith("postgres"):
+ await database.fetch_one("SELECT pg_sleep(0.3)")
+ elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
+ "mariadb"
+ ):
+ await database.fetch_one("SELECT SLEEP(0.3)")
+ elif database_url.dialect.startswith("mssql"):
+ await database.execute("WAITFOR DELAY '00:00:00.300'")
+
+ await asyncio.gather(db_lookup(), db_lookup(), db_lookup())
+
+ await run_database_queries()
+ await database.disconnect()
diff --git a/tests/test_databases.py b/tests/test_databases.py
index c851e94..cee4044 100644
--- a/tests/test_databases.py
+++ b/tests/test_databases.py
@@ -1,8 +1,6 @@
import asyncio
-import contextvars
import datetime
import decimal
-import functools
import os
from collections.abc import Sequence
from unittest.mock import patch
@@ -57,17 +55,6 @@
MIXED_DATABASE_CONFIG_URLS_IDS = [*DATABASE_URLS, *(f"{x}[config]" for x in DATABASE_URLS)]
-try:
- to_thread = asyncio.to_thread
-except AttributeError:
- # for py <= 3.8
- async def to_thread(func, /, *args, **kwargs):
- loop = asyncio.get_running_loop()
- ctx = contextvars.copy_context()
- func_call = functools.partial(ctx.run, func, *args, **kwargs)
- return await loop.run_in_executor(None, func_call)
-
-
@pytest.fixture(params=DATABASE_URLS)
def database_url(request):
"""Yield test database despite its name"""
@@ -831,127 +818,8 @@ async def test_database_url_interface(database_mixed_url):
assert database.url == database_mixed_url
-def _startswith(tested, params):
- for param in params:
- if tested.startswith(param):
- return True
- return False
-
-
-@pytest.mark.asyncio
-async def test_concurrent_access_on_single_connection(database_url):
- database_url = DatabaseURL(str(database_url.url))
- if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
- pytest.skip("Test requires sleep function")
- async with Database(database_url, force_rollback=True) as database:
-
- async def db_lookup():
- if database_url.dialect.startswith("postgres"):
- await database.fetch_one("SELECT pg_sleep(0.3)")
- elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
- "mariadb"
- ):
- await database.fetch_one("SELECT SLEEP(0.3)")
- elif database_url.dialect.startswith("mssql"):
- await database.execute("WAITFOR DELAY '00:00:00.300'")
-
- await asyncio.gather(db_lookup(), db_lookup(), db_lookup())
-
-
-@pytest.mark.parametrize("force_rollback", [True, False])
-@pytest.mark.asyncio
-async def test_multi_thread(database_url, force_rollback):
- database_url = DatabaseURL(str(database_url.url))
- async with Database(database_url, force_rollback=force_rollback) as database:
- database._non_copied_attribute = True
-
- async def db_lookup(in_thread):
- async with database.connection() as conn:
- if in_thread:
- assert not hasattr(conn._database, "_non_copied_attribute")
- else:
- assert hasattr(conn._database, "_non_copied_attribute")
- assert bool(conn._database.force_rollback) == force_rollback
- if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
- return
- if database_url.dialect.startswith("postgres"):
- await database.fetch_one("SELECT pg_sleep(0.3)")
- elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
- "mariadb"
- ):
- await database.fetch_one("SELECT SLEEP(0.3)")
- elif database_url.dialect.startswith("mssql"):
- await database.execute("WAITFOR DELAY '00:00:00.300'")
-
- async def wrap_in_thread():
- await to_thread(asyncio.run, db_lookup(True))
-
- await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread())
-
-
-@pytest.mark.parametrize("force_rollback", [True, False])
@pytest.mark.asyncio
-async def test_multi_thread_db_contextmanager(database_url, force_rollback):
- async with Database(database_url, force_rollback=force_rollback) as database:
- query = notes.insert().values(text="examplecontext", completed=True)
- await database.execute(query)
-
- async def db_connect(depth=3):
- # many parallel and nested threads
- async with database as new_database:
- query = notes.select()
- result = await database.fetch_one(query)
- assert result.text == "examplecontext"
- assert result.completed is True
- # test delegate to sub database
- assert database.engine is new_database.engine
- # also this shouldn't fail because redirected
- old_refcount = new_database.ref_counter
- await database.connect()
- assert new_database.ref_counter == old_refcount + 1
- await database.disconnect()
- ops = []
- while depth >= 0:
- depth -= 1
- ops.append(to_thread(asyncio.run, db_connect(depth=depth)))
- await asyncio.gather(*ops)
- assert new_database.ref_counter == 0
-
- await to_thread(asyncio.run, db_connect())
- assert database.ref_counter == 0
- if force_rollback:
- async with database:
- query = notes.select()
- result = await database.fetch_one(query)
- assert result is None
-
-
-@pytest.mark.asyncio
-async def test_multi_thread_db_connect(database_url):
- async with Database(database_url, force_rollback=True) as database:
-
- async def db_connect():
- await database.connect()
- await database.fetch_one("SELECT 1")
- await database.disconnect()
-
- await to_thread(asyncio.run, db_connect())
-
-
-@pytest.mark.asyncio
-async def test_multi_thread_db_fails(database_url):
- async with Database(database_url, force_rollback=True) as database:
-
- async def db_connect():
- # not in same loop
- database.disconnect()
-
- with pytest.raises(RuntimeError):
- await to_thread(asyncio.run, db_connect())
-
-
-@pytest.mark.asyncio
-async def test_error_on_passed_parent_database(database_url):
+async def test_error_on_passed_parent_database_argument(database_url):
database = Database(database_url)
# don't allow specifying parent_database
with pytest.raises(AssertionError):
@@ -962,41 +830,6 @@ async def test_error_on_passed_parent_database(database_url):
await database.disconnect(False, None)
-@pytest.mark.asyncio
-async def test_global_connection_is_initialized_lazily(database_url):
- """
- Ensure that global connection is initialized at latest possible time
- so it's _query_lock will belong to same event loop that async_adapter has
- initialized.
-
- See https://github.com/dymmond/databasez/issues/157 for more context.
- """
-
- database_url = DatabaseURL(database_url.url)
- if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
- pytest.skip("Test requires sleep function")
-
- database = Database(database_url, force_rollback=True)
-
- async def run_database_queries():
- async with database:
-
- async def db_lookup():
- if database_url.dialect.startswith("postgres"):
- await database.fetch_one("SELECT pg_sleep(0.3)")
- elif database_url.dialect.startswith("mysql") or database_url.dialect.startswith(
- "mariadb"
- ):
- await database.fetch_one("SELECT SLEEP(0.3)")
- elif database_url.dialect.startswith("mssql"):
- await database.execute("WAITFOR DELAY '00:00:00.300'")
-
- await asyncio.gather(db_lookup(), db_lookup(), db_lookup())
-
- await run_database_queries()
- await database.disconnect()
-
-
@pytest.mark.parametrize("select_query", [notes.select(), "SELECT * FROM notes"])
@pytest.mark.asyncio
async def test_column_names(database_url, select_query):