Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- finish full_isolation
- fix interface of Connection
  • Loading branch information
devkral committed Aug 26, 2024
1 parent c1d1a35 commit 49ab625
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 35 deletions.
104 changes: 86 additions & 18 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from sqlalchemy import MetaData
from sqlalchemy.sql import ClauseElement

from databasez.types import BatchCallable, BatchCallableResult

from .database import Database


async def _startup(database: Database, is_initialized: Event) -> None:
await database.connect()
await database._global_connection._aenter()
await database._global_connection._aenter() # type: ignore
is_initialized.set()


Expand Down Expand Up @@ -73,7 +75,7 @@ def __init__(
self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

@multiloop_protector(False)
@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def _aenter(self) -> None:
async with self._connection_lock:
self._connection_counter += 1
Expand Down Expand Up @@ -158,9 +160,9 @@ async def _aexit(self) -> typing.Optional[Thread]:
return thread
else:
await self._aexit_raw()
return None
return None

@multiloop_protector(False)
@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
Expand All @@ -179,14 +181,17 @@ def _loop(self) -> typing.Any:
return self._database._loop

@property
def _backend(self) -> typing.Any:
def _backend(self) -> interfaces.DatabaseBackend:
return self._database.backend

@multiloop_protector(False)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.List[interfaces.Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
Expand All @@ -198,6 +203,9 @@ async def fetch_one(
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
pos: int = 0,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Optional[interfaces.Record]:
built_query = self._build_query(query, values)
async with self._query_lock:
Expand All @@ -210,6 +218,9 @@ async def fetch_val(
values: typing.Optional[dict] = None,
column: typing.Any = 0,
pos: int = 0,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Any:
built_query = self._build_query(query, values)
async with self._query_lock:
Expand All @@ -220,6 +231,9 @@ async def execute(
self,
query: typing.Union[ClauseElement, str],
values: typing.Any = None,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Union[interfaces.Record, int]:
if isinstance(query, str):
built_query = self._build_query(query, values)
Expand All @@ -231,7 +245,12 @@ async def execute(

@multiloop_protector(False)
async def execute_many(
self, query: typing.Union[ClauseElement, str], values: typing.Any = None
self,
query: typing.Union[ClauseElement, str],
values: typing.Any = None,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
if isinstance(query, str):
built_query = self._build_query(query, None)
Expand All @@ -241,46 +260,90 @@ async def execute_many(
async with self._query_lock:
return await self._connection.execute_many(query, values)

@multiloop_protector(False)
@multiloop_protector(False, passthrough_timeout=True)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
batch_size: typing.Optional[int] = None,
) -> typing.AsyncGenerator[typing.Any, None]:
chunk_size: typing.Optional[int] = None,
timeout: typing.Optional[float] = None,
) -> typing.AsyncGenerator[interfaces.Record, None]:
built_query = self._build_query(query, values)
if timeout is None or timeout <= 0:
next_fn: typing.Callable[[typing.Any], typing.Awaitable[interfaces.Record]] = anext
else:

async def next_fn(inp: typing.Any) -> interfaces.Record:
return await asyncio.wait_for(anext(aiterator), timeout=timeout)

async with self._query_lock:
async for record in self._connection.iterate(built_query, batch_size):
yield record
aiterator = self._connection.iterate(built_query, chunk_size).__aiter__()
try:
while True:
yield await next_fn(aiterator)
except StopAsyncIteration:
pass

@multiloop_protector(False)
@multiloop_protector(False, passthrough_timeout=True)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
batch_size: typing.Optional[int] = None,
) -> typing.AsyncGenerator[typing.Any, None]:
batch_wrapper: BatchCallable = tuple,
timeout: typing.Optional[float] = None,
) -> typing.AsyncGenerator[BatchCallableResult, None]:
built_query = self._build_query(query, values)
if timeout is None or timeout <= 0:
next_fn: typing.Callable[
[typing.Any], typing.Awaitable[typing.Sequence[interfaces.Record]]
] = anext
else:

async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]:
return await asyncio.wait_for(anext(aiterator), timeout=timeout)

async with self._query_lock:
async for records in self._connection.batched_iterate(built_query, batch_size):
yield records
aiterator = self._connection.batched_iterate(built_query, batch_size).__aiter__()
try:
while True:
yield batch_wrapper(await next_fn(aiterator))
except StopAsyncIteration:
pass

@multiloop_protector(False)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
*args: typing.Any,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
**kwargs: typing.Any,
) -> typing.Any:
async with self._query_lock:
return await self._connection.run_sync(fn, *args, **kwargs)

@multiloop_protector(False)
async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
async def create_all(
self,
meta: MetaData,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
**kwargs: typing.Any,
) -> None:
await self.run_sync(meta.create_all, **kwargs)

@multiloop_protector(False)
async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
async def drop_all(
self,
meta: MetaData,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
**kwargs: typing.Any,
) -> None:
await self.run_sync(meta.drop_all, **kwargs)

def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
Expand All @@ -293,7 +356,12 @@ def async_connection(self) -> typing.Any:
return self._connection.async_connection

@multiloop_protector(False)
async def get_raw_connection(self) -> typing.Any:
async def get_raw_connection(
self,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> typing.Any:
"""The real raw connection (driver)."""
return await self.async_connection.get_raw_connection()

Expand Down
33 changes: 25 additions & 8 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Database:
options: typing.Any
is_connected: bool = False
_call_hooks: bool = True
_full_isolation: bool = False
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
Expand Down Expand Up @@ -398,7 +399,11 @@ async def fetch_val(
) -> typing.Any:
async with self.connection() as connection:
return await connection.fetch_val(
query, values, column=column, pos=pos, timeout=timeout
query,
values,
column=column,
pos=pos,
timeout=timeout,
)

async def execute(
Expand Down Expand Up @@ -435,14 +440,21 @@ async def batched_iterate(
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
batch_size: typing.Optional[int] = None,
batch_wrapper: typing.Union[BatchCallable] = tuple,
batch_wrapper: BatchCallable = tuple,
timeout: typing.Optional[float] = None,
) -> typing.AsyncGenerator[BatchCallableResult, None]:
async with self.connection() as connection:
async for records in connection.batched_iterate(
query, values, batch_size, timeout=timeout
async for batch in typing.cast(
typing.AsyncGenerator["BatchCallableResult", None],
connection.batched_iterate(
query,
values,
batch_wrapper=batch_wrapper,
batch_size=batch_size,
timeout=timeout,
),
):
yield batch_wrapper(records)
yield batch

def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
Expand Down Expand Up @@ -470,18 +482,23 @@ async def drop_all(
await connection.drop_all(meta, **kwargs, timeout=timeout)

@multiloop_protector(False)
def _non_global_connection(self) -> Connection:
def _non_global_connection(
self,
timeout: typing.Optional[
float
] = None, # stub for type checker, multiloop_protector handles timeout
) -> Connection:
if self._connection is None:
_connection = self._connection = Connection(self)
return _connection
return self._connection

def connection(self) -> Connection:
def connection(self, timeout: typing.Optional[float] = None) -> Connection:
if not self.is_connected:
raise RuntimeError("Database is not connected")
if self.force_rollback:
return typing.cast(Connection, self._global_connection)
return self._non_global_connection()
return self._non_global_connection(timeout=timeout)

@property
@multiloop_protector(True)
Expand Down
2 changes: 1 addition & 1 deletion databasez/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DatabaseTestClient(Database):
# is used for copying Database and DatabaseTestClientand providing an early url
test_db_url: str
# hooks for overwriting defaults of args with None
testclient_default_full_isolation: bool = False
testclient_default_full_isolation: bool = True
testclient_default_force_rollback: bool = False
testclient_default_lazy_setup: bool = False
# customization hooks
Expand Down
14 changes: 7 additions & 7 deletions databasez/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def __aenter__(self) -> typing.Any:
self.ctm = await _arun_with_timeout(
self.fn(database, *self.args, **self.kwargs), timeout=self.timeout
)
return await _arun_with_timeout(self.ctm.__aenter__(), self.timeout)
return await self.ctm.__aenter__()

async def __aexit__(
self,
Expand All @@ -194,9 +194,7 @@ async def __aexit__(
) -> None:
assert self.ctm is not None
try:
await _arun_with_timeout(
self.ctm.__aexit__(exc_type, exc_value, traceback), self.timeout
)
await _arun_with_timeout(self.ctm.__aexit__(exc_type, exc_value, traceback), None)
finally:
await self.database.__aexit__()

Expand Down Expand Up @@ -230,7 +228,7 @@ def __await__(self) -> typing.Any:
async def enter_intern(self) -> typing.Any:
await self.connection.__aenter__()
self.ctm = await self.call()
return await _arun_with_timeout(self.ctm.__aenter__(), self.timeout)
return await self.ctm.__aenter__()

async def exit_intern(self) -> typing.Any:
assert self.ctm is not None
Expand All @@ -256,7 +254,7 @@ async def __aexit__(


def multiloop_protector(
fail_with_different_loop: bool, inject_parent: bool = False
fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False
) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]:
"""For multiple threads or other reasons why the loop changes"""

Expand All @@ -267,9 +265,11 @@ def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable:
def wrapper(
self: typing.Any,
*args: typing.Any,
timeout: typing.Optional[float] = None,
**kwargs: typing.Any,
) -> typing.Any:
timeout: typing.Optional[float] = None
if not passthrough_timeout and "timeout" in kwargs:
timeout = kwargs.pop("timeout")
if inject_parent:
assert "parent_database" not in kwargs, '"parent_database" is a reserved keyword'
try:
Expand Down
4 changes: 3 additions & 1 deletion docs/connections-and-transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ from databasez import Database
<sup>Default: `None`</sup>

* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
in an extra thread. This way it is possible to use blocking operations.
in an extra thread. This way it is possible to use blocking operations like locks with force_rollback.
This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized).
It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter.

<sup>Default: `None`</sup>

Expand Down
5 changes: 5 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
### Added

- `full_isolation` parameter. Isolate the force_rollback Connection in a thread.
- Timeouts for operations.

### Fixed

- `batched_iterate` interface of `Connection` differed from the one of `Database`.


## 0.9.7
Expand Down
8 changes: 8 additions & 0 deletions docs/test-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ from databasez.testclient import DatabaseTestClient

<sup>Default: `None`, copy default or `testclient_default_force_rollback` (defaults to `False`) </sup>

* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
in an extra thread. This way it is possible to use blocking operations like locks with force_rollback.
This parameter has no use when used without force_rollback and causes a slightly slower setup (Lock is initialized).
It is required for edgy or other frameworks which use threads in tests and the force_rollback parameter.
For the DatabaseTestClient it is enabled by default.

<sup>Default: `None`, copy default or `testclient_default_full_isolation` (defaults to `True`) </sup>

* **lazy_setup** - This sets up the db first up on connect not in init.

<sup>Default: `None`, True if copying a database or `testclient_default_lazy_setup` (defaults to `False`)</sup>
Expand Down

0 comments on commit 49ab625

Please sign in to comment.