diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 1f92b82..4bf0706 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -16,7 +16,7 @@ jobs: runs-on: "ubuntu-latest" strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] services: mariadb: diff --git a/README.md b/README.md index 302c6c0..8f9e5a5 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Databasez is suitable for integrating against any async Web framework, such as [ [Starlette][starlette], [Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], or [FastAPI][fastapi]. -Databasez was built for Python 3.8+ and on the top of the newest **SQLAlchemy 2** and gives you +Databasez was built for Python 3.9+ and on the top of the newest **SQLAlchemy 2** and gives you simple asyncio support for a range of databases. ### Special notes diff --git a/databasez/core/asgi.py b/databasez/core/asgi.py index a82e5bc..e2a1c4e 100644 --- a/databasez/core/asgi.py +++ b/databasez/core/asgi.py @@ -1,17 +1,18 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from contextlib import suppress from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from databasez.core.Database import Database ASGIApp = Callable[ [ - Dict[str, Any], - Callable[[], Awaitable[Dict[str, Any]]], - Callable[[Dict[str, Any]], Awaitable[None]], + dict[str, Any], + Callable[[], Awaitable[dict[str, Any]]], + Callable[[dict[str, Any]], Awaitable[None]], ], Awaitable[None], ] @@ -29,14 +30,14 @@ class ASGIHelper: async def __call__( self, - scope: Dict[str, Any], - receive: Callable[[], Awaitable[Dict[str, Any]]], - send: Callable[[Dict[str, Any]], Awaitable[None]], + scope: dict[str, Any], + receive: Callable[[], Awaitable[dict[str, Any]]], + send: Callable[[dict[str, Any]], Awaitable[None]], ) -> None: if scope["type"] == "lifespan": original_receive = receive - async def receive() -> Dict[str, Any]: + async def receive() -> dict[str, Any]: message = await original_receive() if message["type"] == "lifespan.startup": try: diff --git a/databasez/core/connection.py b/databasez/core/connection.py index de585cf..bd8eb08 100644 --- a/databasez/core/connection.py +++ b/databasez/core/connection.py @@ -1,11 +1,12 @@ from __future__ import annotations import asyncio -import typing import weakref +from collections.abc import AsyncGenerator, Callable, Sequence from functools import partial from threading import Event, Lock, Thread, current_thread from types import TracebackType +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import text @@ -14,7 +15,7 @@ from .transaction import Transaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy import MetaData from sqlalchemy.sql import ClauseElement @@ -25,7 +26,7 @@ async def _startup(database: Database, is_initialized: Event) -> None: await database.connect() - _global_connection = typing.cast(Connection, database._global_connection) + _global_connection = cast(Connection, database._global_connection) await _global_connection._aenter() # we ensure fresh locks _global_connection._query_lock = asyncio.Lock() @@ -64,30 +65,30 @@ class AsyncHelperConnection: def __init__( self, connection: Connection, - fn: typing.Callable, - args: typing.Any, - kwargs: typing.Any, - timeout: typing.Optional[float], + fn: Callable, + args: Any, + kwargs: Any, + timeout: float | None, ) -> None: self.connection = connection self.fn = partial(fn, self.connection, *args, **kwargs) self.timeout = timeout self.ctm = None - async def call(self) -> typing.Any: + async def call(self) -> Any: # is automatically awaited result = await _arun_with_timeout(self.fn(), self.timeout) return result - async def acall(self) -> typing.Any: + async def acall(self) -> Any: return await arun_coroutine_threadsafe( self.call(), self.connection._loop, self.connection.poll_interval ) - def __await__(self) -> typing.Any: + def __await__(self) -> Any: return self.acall().__await__() - async def __aiter__(self) -> typing.Any: + async def __aiter__(self) -> Any: result = await self.acall() try: while True: @@ -102,17 +103,17 @@ async def __aiter__(self) -> typing.Any: class Connection: # async helper - async_helper: typing.Type[AsyncHelperConnection] = AsyncHelperConnection + async_helper: type[AsyncHelperConnection] = AsyncHelperConnection def __init__( self, database: Database, force_rollback: bool = False, full_isolation: bool = False ) -> None: self._orig_database = self._database = database self._full_isolation = full_isolation - self._connection_thread_lock: typing.Optional[Lock] = None - self._connection_thread_is_initialized: typing.Optional[Event] = None - self._connection_thread_running_lock: typing.Optional[Lock] = None - self._isolation_thread: typing.Optional[Thread] = None + self._connection_thread_lock: Lock | None = None + self._connection_thread_is_initialized: Event | None = None + self._connection_thread_running_lock: Lock | None = None + self._isolation_thread: Thread | None = None if self._full_isolation: self._connection_thread_lock = Lock() self._connection_thread_is_initialized = Event() @@ -131,12 +132,10 @@ def __init__( self._connection_counter = 0 # for keeping weak references to transactions active - self._transaction_stack: typing.List[ - typing.Tuple[Transaction, interfaces.TransactionBackend] - ] = [] + self._transaction_stack: list[tuple[Transaction, interfaces.TransactionBackend]] = [] self._force_rollback = force_rollback - self.connection_transaction: typing.Optional[Transaction] = None + self.connection_transaction: Transaction | None = None @multiloop_protector(True) def _get_connection_backend(self) -> interfaces.ConnectionBackend: @@ -170,7 +169,7 @@ async def _aenter(self) -> None: async def __aenter__(self) -> Connection: initialized: bool = False if self._full_isolation: - thread: typing.Optional[Thread] = None + thread: Thread | None = None assert self._connection_thread_lock is not None assert self._connection_thread_is_initialized is not None assert self._connection_thread_running_lock is not None @@ -232,7 +231,7 @@ async def _aexit_raw(self) -> bool: return closing @multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout - async def _aexit(self) -> typing.Optional[Thread]: + async def _aexit(self) -> Thread | None: if self._full_isolation: assert self._connection_thread_lock is not None # the lock must be held on exit @@ -251,9 +250,9 @@ async def _aexit(self) -> typing.Optional[Thread]: async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: thread = await self._aexit() if thread is not None and thread is not current_thread(): @@ -262,7 +261,7 @@ async def __aexit__( thread.join(1) @property - def _loop(self) -> typing.Optional[asyncio.AbstractEventLoop]: + def _loop(self) -> asyncio.AbstractEventLoop | None: return self._database._loop @property @@ -278,12 +277,10 @@ def _backend(self) -> interfaces.DatabaseBackend: @multiloop_protector(False) async def fetch_all( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.List[interfaces.Record]: + query: ClauseElement | str, + values: dict | None = None, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> list[interfaces.Record]: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_all(built_query) @@ -291,13 +288,11 @@ async def fetch_all( @multiloop_protector(False) async def fetch_one( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, + query: ClauseElement | str, + values: dict | None = None, pos: int = 0, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.Optional[interfaces.Record]: + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> interfaces.Record | None: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_one(built_query, pos=pos) @@ -305,14 +300,12 @@ async def fetch_one( @multiloop_protector(False) async def fetch_val( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - column: typing.Any = 0, + query: ClauseElement | str, + values: dict | None = None, + column: Any = 0, pos: int = 0, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.Any: + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> Any: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_val(built_query, column, pos=pos) @@ -320,12 +313,10 @@ async def fetch_val( @multiloop_protector(False) async def execute( self, - query: typing.Union[ClauseElement, str], - values: typing.Any = None, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.Union[interfaces.Record, int]: + query: ClauseElement | str, + values: Any = None, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> interfaces.Record | int: if isinstance(query, str): built_query = self._build_query(query, values) async with self._query_lock: @@ -337,12 +328,10 @@ async def execute( @multiloop_protector(False) async def execute_many( self, - query: typing.Union[ClauseElement, str], - values: typing.Any = None, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.Union[typing.Sequence[interfaces.Record], int]: + query: ClauseElement | str, + values: Any = None, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> Sequence[interfaces.Record] | int: if isinstance(query, str): built_query = self._build_query(query, None) async with self._query_lock: @@ -354,20 +343,20 @@ async def execute_many( @multiloop_protector(False, passthrough_timeout=True) async def iterate( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - chunk_size: typing.Optional[int] = None, - timeout: typing.Optional[float] = None, - ) -> typing.AsyncGenerator[interfaces.Record, None]: + query: ClauseElement | str, + values: dict | None = None, + chunk_size: int | None = None, + timeout: float | None = None, + ) -> AsyncGenerator[interfaces.Record, None]: built_query = self._build_query(query, values) if timeout is None or timeout <= 0: # anext is available in python 3.10 - async def next_fn(inp: typing.Any) -> interfaces.Record: + async def next_fn(inp: Any) -> interfaces.Record: return await aiterator.__anext__() else: - async def next_fn(inp: typing.Any) -> interfaces.Record: + async def next_fn(inp: Any) -> interfaces.Record: return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout) async with self._query_lock: @@ -381,21 +370,21 @@ async def next_fn(inp: typing.Any) -> interfaces.Record: @multiloop_protector(False, passthrough_timeout=True) async def batched_iterate( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - batch_size: typing.Optional[int] = None, + query: ClauseElement | str, + values: dict | None = None, + batch_size: int | None = None, batch_wrapper: BatchCallable = tuple, - timeout: typing.Optional[float] = None, - ) -> typing.AsyncGenerator[BatchCallableResult, None]: + timeout: float | None = None, + ) -> AsyncGenerator[BatchCallableResult, None]: built_query = self._build_query(query, values) if timeout is None or timeout <= 0: # anext is available in python 3.10 - async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: + async def next_fn(inp: Any) -> Sequence[interfaces.Record]: return await aiterator.__anext__() else: - async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: + async def next_fn(inp: Any) -> Sequence[interfaces.Record]: return await asyncio.wait_for(aiterator.__anext__(), timeout=timeout) async with self._query_lock: @@ -409,13 +398,11 @@ async def next_fn(inp: typing.Any) -> typing.Sequence[interfaces.Record]: @multiloop_protector(False) async def run_sync( self, - fn: typing.Callable[..., typing.Any], - *args: typing.Any, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - **kwargs: typing.Any, - ) -> typing.Any: + fn: Callable[..., Any], + *args: Any, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: Any, + ) -> Any: async with self._query_lock: return await self._connection.run_sync(fn, *args, **kwargs) @@ -423,10 +410,8 @@ async def run_sync( async def create_all( self, meta: MetaData, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - **kwargs: typing.Any, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: Any, ) -> None: await self.run_sync(meta.create_all, **kwargs) @@ -434,36 +419,30 @@ async def create_all( async def drop_all( self, meta: MetaData, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - **kwargs: typing.Any, + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + **kwargs: Any, ) -> None: await self.run_sync(meta.drop_all, **kwargs) - def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": + def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: return Transaction(weakref.ref(self), force_rollback, **kwargs) @property @multiloop_protector(True) - def async_connection(self) -> typing.Any: + def async_connection(self) -> Any: """The first layer (sqlalchemy).""" return self._connection.async_connection @multiloop_protector(False) async def get_raw_connection( self, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout - ) -> typing.Any: + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout + ) -> Any: """The real raw connection (driver).""" return await self.async_connection.get_raw_connection() @staticmethod - def _build_query( - query: typing.Union[ClauseElement, str], values: typing.Optional[typing.Any] = None - ) -> ClauseElement: + def _build_query(query: ClauseElement | str, values: Any | None = None) -> ClauseElement: if isinstance(query, str): query = text(query) diff --git a/databasez/core/database.py b/databasez/core/database.py index ab08361..e272847 100644 --- a/databasez/core/database.py +++ b/databasez/core/database.py @@ -4,11 +4,12 @@ import contextlib import importlib import logging -import typing import weakref +from collections.abc import AsyncGenerator, Callable, Iterator, Sequence from contextvars import ContextVar from functools import lru_cache, partial from types import TracebackType +from typing import TYPE_CHECKING, Any, cast, overload from databasez import interfaces from databasez.utils import ( @@ -23,7 +24,7 @@ from .databaseurl import DatabaseURL from .transaction import Transaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy import URL, MetaData from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.sql import ClauseElement @@ -47,9 +48,9 @@ logger = logging.getLogger("databasez") -default_database: typing.Type[interfaces.DatabaseBackend] -default_connection: typing.Type[interfaces.ConnectionBackend] -default_transaction: typing.Type[interfaces.TransactionBackend] +default_database: type[interfaces.DatabaseBackend] +default_connection: type[interfaces.ConnectionBackend] +default_transaction: type[interfaces.TransactionBackend] @lru_cache(1) @@ -69,9 +70,9 @@ def init() -> None: ) -ACTIVE_FORCE_ROLLBACKS: ContextVar[ - typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]] -] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None) +ACTIVE_FORCE_ROLLBACKS: ContextVar[weakref.WeakKeyDictionary[ForceRollback, bool] | None] = ( + ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None) +) class ForceRollback: @@ -80,7 +81,7 @@ class ForceRollback: def __init__(self, default: bool): self.default = default - def set(self, value: typing.Union[bool, None] = None) -> None: + def set(self, value: bool | None = None) -> None: force_rollbacks = ACTIVE_FORCE_ROLLBACKS.get() if force_rollbacks is None: # shortcut, we don't need to initialize anything for None (reset) @@ -103,7 +104,7 @@ def __bool__(self) -> bool: return force_rollbacks.get(self, self.default) @contextlib.contextmanager - def __call__(self, force_rollback: bool = True) -> typing.Iterator[None]: + def __call__(self, force_rollback: bool = True) -> Iterator[None]: initial = bool(self) self.set(force_rollback) try: @@ -113,10 +114,10 @@ def __call__(self, force_rollback: bool = True) -> typing.Iterator[None]: class ForceRollbackDescriptor: - def __get__(self, obj: Database, objtype: typing.Type[Database]) -> ForceRollback: + def __get__(self, obj: Database, objtype: type[Database]) -> ForceRollback: return obj._force_rollback - def __set__(self, obj: Database, value: typing.Union[bool, None]) -> None: + def __set__(self, obj: Database, value: bool | None) -> None: assert value is None or isinstance(value, bool), f"Invalid type: {value!r}." obj._force_rollback.set(value) @@ -128,10 +129,10 @@ class AsyncHelperDatabase: def __init__( self, database: Database, - fn: typing.Callable, - args: typing.Any, - kwargs: typing.Any, - timeout: typing.Optional[float], + fn: Callable, + args: Any, + kwargs: Any, + timeout: float | None, ) -> None: self.database = database self.fn = fn @@ -140,16 +141,16 @@ def __init__( self.timeout = timeout self.ctm = None - async def call(self) -> typing.Any: + async def call(self) -> Any: async with self.database as database: return await _arun_with_timeout( self.fn(database, *self.args, **self.kwargs), self.timeout ) - def __await__(self) -> typing.Any: + def __await__(self) -> Any: return self.call().__await__() - async def __aenter__(self) -> typing.Any: + async def __aenter__(self) -> Any: database = await self.database.__aenter__() self.ctm = await _arun_with_timeout( self.fn(database, *self.args, **self.kwargs), timeout=self.timeout @@ -158,9 +159,9 @@ async def __aenter__(self) -> typing.Any: async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: assert self.ctm is not None try: @@ -194,11 +195,11 @@ class Database: """ _connection_map: weakref.WeakKeyDictionary[asyncio.Task, Connection] - _databases_map: typing.Dict[typing.Any, Database] - _loop: typing.Optional[asyncio.AbstractEventLoop] = None + _databases_map: dict[Any, Database] + _loop: asyncio.AbstractEventLoop | None = None backend: interfaces.DatabaseBackend url: DatabaseURL - options: typing.Any + options: Any is_connected: bool = False _call_hooks: bool = True _remove_global_connection: bool = True @@ -208,18 +209,18 @@ class Database: # descriptor force_rollback = ForceRollbackDescriptor() # async helper - async_helper: typing.Type[AsyncHelperDatabase] = AsyncHelperDatabase + async_helper: type[AsyncHelperDatabase] = AsyncHelperDatabase def __init__( self, - url: typing.Union[str, DatabaseURL, URL, Database, None] = None, + url: str | DatabaseURL | URL | Database | None = None, *, - force_rollback: typing.Union[bool, None] = None, - config: typing.Optional["DictAny"] = None, - full_isolation: typing.Union[bool, None] = None, + force_rollback: bool | None = None, + config: DictAny | None = None, + full_isolation: bool | None = None, # for - poll_interval: typing.Union[float, None] = None, - **options: typing.Any, + poll_interval: float | None = None, + **options: Any, ): init() assert config is None or url is None, "Use either 'url' or 'config', not both." @@ -260,7 +261,7 @@ def __init__( # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. - self._global_connection: typing.Optional[Connection] = None + self._global_connection: Connection | None = None self.ref_counter: int = 0 self.ref_lock: asyncio.Lock = asyncio.Lock() @@ -276,11 +277,11 @@ def _current_task(self) -> asyncio.Task: return task @property - def _connection(self) -> typing.Optional[Connection]: + def _connection(self) -> Connection | None: return self._connection_map.get(self._current_task) @_connection.setter - def _connection(self, connection: typing.Optional[Connection]) -> typing.Optional[Connection]: + def _connection(self, connection: Connection | None) -> Connection | None: task = self._current_task if connection is None: @@ -376,7 +377,7 @@ async def disconnect_hook(self) -> None: @multiloop_protector(True, inject_parent=True) async def disconnect( - self, force: bool = False, *, parent_database: typing.Optional[Database] = None + self, force: bool = False, *, parent_database: Database | None = None ) -> bool: """ Close all connections in the connection pool. @@ -393,14 +394,13 @@ async def disconnect( if parent_database is not None: loop = asyncio.get_running_loop() del parent_database._databases_map[loop] - if force: - if self._databases_map: - assert not self._databases_map, "sub databases still active, force terminate them" - for sub_database in self._databases_map.values(): - await arun_coroutine_threadsafe( - sub_database.disconnect(True), sub_database._loop, self.poll_interval - ) - self._databases_map = {} + if force and self._databases_map: + assert not self._databases_map, "sub databases still active, force terminate them" + for sub_database in self._databases_map.values(): + await arun_coroutine_threadsafe( + sub_database.disconnect(True), sub_database._loop, self.poll_interval + ) + self._databases_map = {} assert not self._databases_map, "sub databases still active" try: @@ -422,7 +422,7 @@ async def disconnect( await self.disconnect_hook() return True - async def __aenter__(self) -> "Database": + async def __aenter__(self) -> Database: await self.connect() # get right database loop = asyncio.get_running_loop() @@ -431,39 +431,39 @@ async def __aenter__(self) -> "Database": async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: await self.disconnect() async def fetch_all( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - timeout: typing.Optional[float] = None, - ) -> typing.List[interfaces.Record]: + query: ClauseElement | str, + values: dict | None = None, + timeout: float | None = None, + ) -> list[interfaces.Record]: async with self.connection() as connection: return await connection.fetch_all(query, values, timeout=timeout) async def fetch_one( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, + query: ClauseElement | str, + values: dict | None = None, pos: int = 0, - timeout: typing.Optional[float] = None, - ) -> typing.Optional[interfaces.Record]: + timeout: float | None = None, + ) -> interfaces.Record | None: async with self.connection() as connection: return await connection.fetch_one(query, values, pos=pos, timeout=timeout) async def fetch_val( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - column: typing.Any = 0, + query: ClauseElement | str, + values: dict | None = None, + column: Any = 0, pos: int = 0, - timeout: typing.Optional[float] = None, - ) -> typing.Any: + timeout: float | None = None, + ) -> Any: async with self.connection() as connection: return await connection.fetch_val( query, @@ -475,44 +475,44 @@ async def fetch_val( async def execute( self, - query: typing.Union[ClauseElement, str], - values: typing.Any = None, - timeout: typing.Optional[float] = None, - ) -> typing.Union[interfaces.Record, int]: + query: ClauseElement | str, + values: Any = None, + timeout: float | None = None, + ) -> interfaces.Record | int: async with self.connection() as connection: return await connection.execute(query, values, timeout=timeout) async def execute_many( self, - query: typing.Union[ClauseElement, str], - values: typing.Any = None, - timeout: typing.Optional[float] = None, - ) -> typing.Union[typing.Sequence[interfaces.Record], int]: + query: ClauseElement | str, + values: Any = None, + timeout: float | None = None, + ) -> Sequence[interfaces.Record] | int: async with self.connection() as connection: return await connection.execute_many(query, values, timeout=timeout) async def iterate( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - chunk_size: typing.Optional[int] = None, - timeout: typing.Optional[float] = None, - ) -> typing.AsyncGenerator[interfaces.Record, None]: + query: ClauseElement | str, + values: dict | None = None, + chunk_size: int | None = None, + timeout: float | None = None, + ) -> AsyncGenerator[interfaces.Record, None]: async with self.connection() as connection: async for record in connection.iterate(query, values, chunk_size, timeout=timeout): yield record async def batched_iterate( self, - query: typing.Union[ClauseElement, str], - values: typing.Optional[dict] = None, - batch_size: typing.Optional[int] = None, + query: ClauseElement | str, + values: dict | None = None, + batch_size: int | None = None, batch_wrapper: BatchCallable = tuple, - timeout: typing.Optional[float] = None, - ) -> typing.AsyncGenerator[BatchCallableResult, None]: + timeout: float | None = None, + ) -> AsyncGenerator[BatchCallableResult, None]: async with self.connection() as connection: - async for batch in typing.cast( - typing.AsyncGenerator["BatchCallableResult", None], + async for batch in cast( + AsyncGenerator["BatchCallableResult", None], connection.batched_iterate( query, values, @@ -523,63 +523,59 @@ async def batched_iterate( ): yield batch - def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": + def transaction(self, *, force_rollback: bool = False, **kwargs: Any) -> Transaction: return Transaction(self.connection, force_rollback=force_rollback, **kwargs) async def run_sync( self, - fn: typing.Callable[..., typing.Any], - *args: typing.Any, - timeout: typing.Optional[float] = None, - **kwargs: typing.Any, - ) -> typing.Any: + fn: Callable[..., Any], + *args: Any, + timeout: float | None = None, + **kwargs: Any, + ) -> Any: async with self.connection() as connection: return await connection.run_sync(fn, *args, **kwargs, timeout=timeout) async def create_all( - self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any + self, meta: MetaData, timeout: float | None = None, **kwargs: Any ) -> None: async with self.connection() as connection: await connection.create_all(meta, **kwargs, timeout=timeout) - async def drop_all( - self, meta: MetaData, timeout: typing.Optional[float] = None, **kwargs: typing.Any - ) -> None: + async def drop_all(self, meta: MetaData, timeout: float | None = None, **kwargs: Any) -> None: async with self.connection() as connection: await connection.drop_all(meta, **kwargs, timeout=timeout) @multiloop_protector(False) def _non_global_connection( self, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout ) -> Connection: if self._connection is None: _connection = self._connection = Connection(self) return _connection return self._connection - def connection(self, timeout: typing.Optional[float] = None) -> Connection: + def connection(self, timeout: float | None = None) -> Connection: if not self.is_connected: raise RuntimeError("Database is not connected") if self.force_rollback: - return typing.cast(Connection, self._global_connection) + return cast(Connection, self._global_connection) return self._non_global_connection(timeout=timeout) @property @multiloop_protector(True) - def engine(self) -> typing.Optional[AsyncEngine]: + def engine(self) -> AsyncEngine | None: return self.backend.engine - @typing.overload + @overload def asgi( self, app: None, handle_lifespan: bool = False, - ) -> typing.Callable[[ASGIApp], ASGIHelper]: ... + ) -> Callable[[ASGIApp], ASGIHelper]: ... - @typing.overload + @overload def asgi( self, app: ASGIApp, @@ -588,9 +584,9 @@ def asgi( def asgi( self, - app: typing.Optional[ASGIApp] = None, + app: ASGIApp | None = None, handle_lifespan: bool = False, - ) -> typing.Union[ASGIHelper, typing.Callable[[ASGIApp], ASGIHelper]]: + ) -> ASGIHelper | Callable[[ASGIApp], ASGIHelper]: """Return wrapper for asgi integration.""" if app is not None: return ASGIHelper(app=app, database=self, handle_lifespan=handle_lifespan) @@ -602,19 +598,19 @@ def get_backends( # let scheme empty for direct imports scheme: str = "", *, - overwrite_paths: typing.Sequence[str] = ["databasez.overwrites"], + overwrite_paths: Sequence[str] = ["databasez.overwrites"], database_name: str = "Database", connection_name: str = "Connection", transaction_name: str = "Transaction", - database_class: typing.Optional[typing.Type[interfaces.DatabaseBackend]] = None, - connection_class: typing.Optional[typing.Type[interfaces.ConnectionBackend]] = None, - transaction_class: typing.Optional[typing.Type[interfaces.TransactionBackend]] = None, - ) -> typing.Tuple[ - typing.Type[interfaces.DatabaseBackend], - typing.Type[interfaces.ConnectionBackend], - typing.Type[interfaces.TransactionBackend], + database_class: type[interfaces.DatabaseBackend] | None = None, + connection_class: type[interfaces.ConnectionBackend] | None = None, + transaction_class: type[interfaces.TransactionBackend] | None = None, + ) -> tuple[ + type[interfaces.DatabaseBackend], + type[interfaces.ConnectionBackend], + type[interfaces.TransactionBackend], ]: - module: typing.Any = None + module: Any = None for overwrite_path in overwrite_paths: imp_path = f"{overwrite_path}.{scheme.replace('+', '_')}" if scheme else overwrite_path try: @@ -650,11 +646,11 @@ def get_backends( @classmethod def apply_database_url_and_options( cls, - url: typing.Union[DatabaseURL, str], + url: DatabaseURL | str, *, - overwrite_paths: typing.Sequence[str] = ["databasez.overwrites"], - **options: typing.Any, - ) -> typing.Tuple[interfaces.DatabaseBackend, DatabaseURL, typing.Dict[str, typing.Any]]: + overwrite_paths: Sequence[str] = ["databasez.overwrites"], + **options: Any, + ) -> tuple[interfaces.DatabaseBackend, DatabaseURL, dict[str, Any]]: url = DatabaseURL(url) database_class, connection_class, transaction_class = cls.get_backends( url.scheme, diff --git a/databasez/core/databaseurl.py b/databasez/core/databaseurl.py index aaf707c..9b5a294 100644 --- a/databasez/core/databaseurl.py +++ b/databasez/core/databaseurl.py @@ -1,5 +1,7 @@ -import typing +from __future__ import annotations + from functools import cached_property +from typing import Any from urllib.parse import SplitResult, parse_qs, quote, unquote, urlencode, urlsplit from sqlalchemy import URL @@ -7,7 +9,7 @@ class DatabaseURL: - def __init__(self, url: typing.Union[str, "DatabaseURL", URL, None] = None): + def __init__(self, url: str | DatabaseURL | URL | None = None): if isinstance(url, DatabaseURL): self._url: str = url._url elif isinstance(url, URL): @@ -47,14 +49,14 @@ def dialect(self) -> str: return self.scheme.split("+")[0] @property - def driver(self) -> typing.Optional[str]: + def driver(self) -> str | None: splitted = self.scheme.split("+", 1) if len(splitted) == 1: return None return splitted[1] @property - def userinfo(self) -> typing.Optional[bytes]: + def userinfo(self) -> bytes | None: if self.components.username: info = quote(self.components.username, safe="+") if self.password: @@ -63,19 +65,19 @@ def userinfo(self) -> typing.Optional[bytes]: return None @property - def username(self) -> typing.Optional[str]: + def username(self) -> str | None: if self.components.username is None: return None return unquote(self.components.username) @property - def password(self) -> typing.Optional[str]: + def password(self) -> str | None: if self.components.password is None: return None return unquote(self.components.password) @property - def hostname(self) -> typing.Optional[str]: + def hostname(self) -> str | None: host = self.components.hostname or self.options.get("host") if isinstance(host, list): if len(host) > 0: @@ -85,11 +87,11 @@ def hostname(self) -> typing.Optional[str]: return host @property - def port(self) -> typing.Optional[int]: + def port(self) -> int | None: return self.components.port @property - def netloc(self) -> typing.Optional[str]: + def netloc(self) -> str | None: return self.components.netloc @property @@ -100,8 +102,8 @@ def database(self) -> str: return unquote(path) @cached_property - def options(self) -> typing.Dict[str, typing.Union[str, typing.List[str]]]: - result: typing.Dict[str, typing.Union[str, typing.List[str]]] = {} + def options(self) -> dict[str, str | list[str]]: + result: dict[str, str | list[str]] = {} for key, val in parse_qs(self.components.query).items(): if len(val) == 1: result[key] = val[0] @@ -109,7 +111,7 @@ def options(self) -> typing.Dict[str, typing.Union[str, typing.List[str]]]: result[key] = val return result - def replace(self, **kwargs: typing.Any) -> "DatabaseURL": + def replace(self, **kwargs: Any) -> DatabaseURL: if ( "username" in kwargs or "user" in kwargs @@ -163,7 +165,7 @@ def obscure_password(self) -> str: def sqla_url(self) -> URL: return make_url(self._url) - def upgrade(self, **extra_options: typing.Any) -> "DatabaseURL": + def upgrade(self, **extra_options: Any) -> DatabaseURL: from .database import Database return Database.apply_database_url_and_options(self, **extra_options)[1] @@ -174,7 +176,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self.obscure_password)})" - def __eq__(self, other: typing.Any) -> bool: + def __eq__(self, other: Any) -> bool: # fix encoding if isinstance(other, str): other = DatabaseURL(other) diff --git a/databasez/core/transaction.py b/databasez/core/transaction.py index d98296d..90b19bd 100644 --- a/databasez/core/transaction.py +++ b/databasez/core/transaction.py @@ -1,56 +1,57 @@ from __future__ import annotations import asyncio -import typing +from collections.abc import Callable, Generator from functools import partial, wraps from types import TracebackType +from typing import TYPE_CHECKING, Any, TypeVar from databasez.utils import _arun_with_timeout, arun_coroutine_threadsafe, multiloop_protector -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .connection import Connection -_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) +_CallableType = TypeVar("_CallableType", bound=Callable) class AsyncHelperTransaction: def __init__( self, - transaction: typing.Any, - fn: typing.Callable, - args: typing.Any, - kwargs: typing.Any, - timeout: typing.Optional[float], + transaction: Any, + fn: Callable, + args: Any, + kwargs: Any, + timeout: float | None, ) -> None: self.transaction = transaction self.fn = partial(fn, self.transaction, *args, **kwargs) self.timeout = timeout self.ctm = None - async def call(self) -> typing.Any: + async def call(self) -> Any: # is automatically awaited return await _arun_with_timeout(self.fn(), self.timeout) - async def acall(self) -> typing.Any: + async def acall(self) -> Any: return await arun_coroutine_threadsafe( self.call(), self.transaction._loop, self.transaction.poll_interval ) - def __await__(self) -> typing.Any: + def __await__(self) -> Any: return self.acall().__await__() class Transaction: # async helper - async_helper: typing.Type[AsyncHelperTransaction] = AsyncHelperTransaction + async_helper: type[AsyncHelperTransaction] = AsyncHelperTransaction def __init__( self, - connection_callable: typing.Callable[[], typing.Optional[Connection]], + connection_callable: Callable[[], Connection | None], force_rollback: bool, - existing_transaction: typing.Optional[typing.Any] = None, - **kwargs: typing.Any, + existing_transaction: Any | None = None, + **kwargs: Any, ) -> None: self._connection_callable = connection_callable self._force_rollback = force_rollback @@ -65,7 +66,7 @@ def connection(self) -> Connection: return conn @property - def _loop(self) -> typing.Optional[asyncio.AbstractEventLoop]: + def _loop(self) -> asyncio.AbstractEventLoop | None: return self.connection._loop @property @@ -83,9 +84,9 @@ async def __aenter__(self) -> Transaction: async def __aexit__( self, - exc_type: typing.Optional[typing.Type[BaseException]] = None, - exc_value: typing.Optional[BaseException] = None, - traceback: typing.Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: """ Called when exiting `async with database.transaction()` @@ -95,7 +96,7 @@ async def __aexit__( else: await self.commit() - def __await__(self) -> typing.Generator[None, None, Transaction]: + def __await__(self) -> Generator[None, None, Transaction]: """ Called if using the low-level `transaction = await database.transaction()` """ @@ -107,7 +108,7 @@ def __call__(self, func: _CallableType) -> _CallableType: """ @wraps(func) - async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + async def wrapper(*args: Any, **kwargs: Any) -> Any: async with self: await func(*args, **kwargs) @@ -117,9 +118,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: @multiloop_protector(False) async def _start( self, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout ) -> None: connection = self.connection assert connection._loop @@ -141,7 +140,7 @@ async def _start( # called directly from connection async def start( self, - timeout: typing.Optional[float] = None, + timeout: float | None = None, cleanup_on_error: bool = True, ) -> Transaction: connection = self.connection @@ -168,9 +167,7 @@ async def start( @multiloop_protector(False) async def commit( self, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout ) -> None: connection = self.connection async with connection._transaction_lock: @@ -186,9 +183,7 @@ async def commit( @multiloop_protector(False) async def rollback( self, - timeout: typing.Optional[ - float - ] = None, # stub for type checker, multiloop_protector handles timeout + timeout: float | None = None, # stub for type checker, multiloop_protector handles timeout ) -> None: connection = self.connection async with connection._transaction_lock: diff --git a/databasez/dialects/dbapi2.py b/databasez/dialects/dbapi2.py index d3c4221..6889d15 100644 --- a/databasez/dialects/dbapi2.py +++ b/databasez/dialects/dbapi2.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import inspect -import typing from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from importlib import import_module from types import ModuleType +from typing import TYPE_CHECKING, Any, Literal import orjson from sqlalchemy.connectors.asyncio import ( @@ -16,13 +18,13 @@ from databasez.utils import AsyncWrapper -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy import URL from sqlalchemy.base import Connection from sqlalchemy.engine.interfaces import ConnectArgsType -def get_pool_for(pool: typing.Literal["thread", "process"]) -> typing.Any: +def get_pool_for(pool: Literal["thread", "process"]) -> Any: assert pool in {"thread", "process"}, "invalid option" if pool == "thread": return ThreadPoolExecutor(max_workers=1, thread_name_prefix="dapi2") @@ -43,30 +45,31 @@ class DBAPI2_dialect(DefaultDialect): def __init__( self, *, - dialect_overwrites: typing.Optional[typing.Dict[str, typing.Any]] = None, - json_serializer: typing.Any = None, - json_deserializer: typing.Any = None, - **kwargs: typing.Any, + dialect_overwrites: dict[str, Any] | None = None, + json_serializer: Any = None, + json_deserializer: Any = None, + **kwargs: Any, ): super().__init__(**kwargs) if dialect_overwrites: for k, v in dialect_overwrites.items(): setattr(self, k, v) - def create_connect_args(self, url: "URL") -> "ConnectArgsType": - dbapi_dsn_driver: typing.Optional[str] = url.query.get("dbapi_dsn_driver") # type: ignore - driver_args: typing.Any = url.query.get("dbapi_driver_args") - dbapi_pool: typing.Optional[str] = url.query.get("dbapi_pool") # type: ignore + def create_connect_args(self, url: URL) -> ConnectArgsType: + dbapi_dsn_driver: str | None = url.query.get("dbapi_dsn_driver") # type: ignore + driver_args: Any = url.query.get("dbapi_driver_args") + dbapi_pool: str | None = url.query.get("dbapi_pool") # type: ignore dbapi_force_async_wrapper: str = url.query.get("dbapi_force_async_wrapper") # type: ignore if driver_args: driver_args = orjson.loads(driver_args) dsn: str = url.difference_update_query( ("dbapi_dsn_driver", "dbapi_driver_args") ).render_as_string(hide_password=False) - if dbapi_dsn_driver: - dsn = dsn.replace("dbapi2:", dbapi_dsn_driver, 1) - else: - dsn = dsn.replace("dbapi2://", "", 1) + dsn = ( + dsn.replace("dbapi2:", dbapi_dsn_driver, 1) + if dbapi_dsn_driver + else dsn.replace("dbapi2://", "", 1) + ) kwargs_passed = { "driver_args": driver_args, "dbapi_pool": dbapi_pool, @@ -81,11 +84,11 @@ def create_connect_args(self, url: "URL") -> "ConnectArgsType": def connect( self, - *arg: typing.Any, - dbapi_pool: typing.Literal["thread", "process"] = "thread", - dbapi_force_async_wrapper: typing.Optional[bool] = None, - driver_args: typing.Any = None, - **kw: typing.Any, + *arg: Any, + dbapi_pool: Literal["thread", "process"] = "thread", + dbapi_force_async_wrapper: bool | None = None, + driver_args: Any = None, + **kw: Any, ) -> AsyncAdapt_dbapi_connection: dbapi_namespace = self.loaded_dbapi if dbapi_force_async_wrapper is None: @@ -105,15 +108,15 @@ def connect( ) @classmethod - def get_pool_class(cls, url: "URL") -> typing.Any: + def get_pool_class(cls, url: URL) -> Any: return AsyncAdaptedQueuePool def has_table( self, - connection: "Connection", + connection: Connection, table_name: str, - schema: typing.Optional[str] = None, - **kw: typing.Any, + schema: str | None = None, + **kw: Any, ) -> bool: stmt = text(f"select * from '{quote(connection, table_name)}' LIMIT 1") try: @@ -122,7 +125,7 @@ def has_table( except Exception: return False - def get_isolation_level(self, dbapi_connection: typing.Any) -> typing.Any: + def get_isolation_level(self, dbapi_connection: Any) -> Any: return None @classmethod diff --git a/databasez/dialects/jdbc.py b/databasez/dialects/jdbc.py index aa9f0ce..332fb00 100644 --- a/databasez/dialects/jdbc.py +++ b/databasez/dialects/jdbc.py @@ -1,6 +1,8 @@ -import typing +from __future__ import annotations + from concurrent.futures import ThreadPoolExecutor from importlib import import_module +from typing import TYPE_CHECKING, Any, Optional, cast import orjson from sqlalchemy.connectors.asyncio import ( @@ -14,7 +16,7 @@ from databasez.utils import AsyncWrapper -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy import URL from sqlalchemy.base import Connection from sqlalchemy.engine.interfaces import ConnectArgsType @@ -36,18 +38,16 @@ class JDBC_dialect(DefaultDialect): def __init__( self, - json_serializer: typing.Any = None, - json_deserializer: typing.Any = None, - **kwargs: typing.Any, + json_serializer: Any = None, + json_deserializer: Any = None, + **kwargs: Any, ) -> None: super().__init__(**kwargs) - def create_connect_args(self, url: "URL") -> "ConnectArgsType": - jdbc_dsn_driver: str = typing.cast(str, url.query["jdbc_dsn_driver"]) - jdbc_driver: typing.Optional[str] = typing.cast( - typing.Optional[str], url.query.get("jdbc_driver") - ) - driver_args: typing.Any = url.query.get("jdbc_driver_args") + def create_connect_args(self, url: URL) -> ConnectArgsType: + jdbc_dsn_driver: str = cast(str, url.query["jdbc_dsn_driver"]) + jdbc_driver: str | None = cast(Optional[str], url.query.get("jdbc_driver")) + driver_args: Any = url.query.get("jdbc_driver_args") if driver_args: driver_args = orjson.loads(driver_args) dsn: str = url.difference_update_query( @@ -57,7 +57,7 @@ def create_connect_args(self, url: "URL") -> "ConnectArgsType": return (dsn,), {"driver_args": driver_args, "driver": jdbc_driver} - def connect(self, *arg: typing.Any, **kw: typing.Any) -> AsyncAdapt_adbapi2_connection: + def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_adbapi2_connection: creator_fn = AsyncWrapper( self.loaded_dbapi, pool=ThreadPoolExecutor(1, thread_name_prefix="jpype"), @@ -74,10 +74,10 @@ def connect(self, *arg: typing.Any, **kw: typing.Any) -> AsyncAdapt_adbapi2_conn def has_table( self, - connection: "Connection", + connection: Connection, table_name: str, - schema: typing.Optional[str] = None, - **kw: typing.Any, + schema: str | None = None, + **kw: Any, ) -> bool: stmt = text(f"select * from '{quote(connection, table_name)}' LIMIT 1") try: @@ -86,17 +86,17 @@ def has_table( except Exception: return False - def get_isolation_level(self, dbapi_connection: typing.Any) -> typing.Any: + def get_isolation_level(self, dbapi_connection: Any) -> Any: return None @classmethod - def get_pool_class(cls, url: "URL") -> typing.Any: + def get_pool_class(cls, url: URL) -> Any: return AsyncAdaptedQueuePool @classmethod def import_dbapi( cls, - ) -> typing.Any: + ) -> Any: return import_module("jpype.dbapi2") diff --git a/databasez/interfaces.py b/databasez/interfaces.py index 622b2b2..e6f7be6 100644 --- a/databasez/interfaces.py +++ b/databasez/interfaces.py @@ -3,12 +3,12 @@ __all__ = ["Record", "DatabaseBackend", "ConnectionBackend", "TransactionBackend"] -import typing import weakref from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Mapping, Sequence +from collections.abc import AsyncGenerator, Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy import AsyncEngine, Transaction from sqlalchemy.sql import ClauseElement @@ -20,55 +20,53 @@ class Record(Sequence): @property - def _mapping(self) -> Mapping[str, typing.Any]: + def _mapping(self) -> Mapping[str, Any]: raise NotImplementedError() # pragma: no cover class TransactionBackend(ABC): - raw_transaction: typing.Optional[Transaction] + raw_transaction: Transaction | None def __init__( self, connection: ConnectionBackend, - existing_transaction: typing.Optional[Transaction] = None, + existing_transaction: Transaction | None = None, ): # cannot be a weak ref otherwise connections get lost when retrieving them via transactions self.connection = connection self.raw_transaction = existing_transaction @property - def connection(self) -> typing.Optional[ConnectionBackend]: + def connection(self) -> ConnectionBackend | None: result = self.__dict__.get("connection") if result is None: return None - return typing.cast(ConnectionBackend, result()) + return cast(ConnectionBackend, result()) @connection.setter def connection(self, value: ConnectionBackend) -> None: self.__dict__["connection"] = weakref.ref(value) @property - def async_connection(self) -> typing.Optional[typing.Any]: + def async_connection(self) -> Any | None: result = self.connection if result is None: return None return result.async_connection @property - def owner(self) -> typing.Optional[RootTransaction]: + def owner(self) -> RootTransaction | None: result = self.__dict__.get("owner") if result is None: return None - return typing.cast("RootTransaction", result()) + return cast("RootTransaction", result()) @owner.setter def owner(self, value: RootTransaction) -> None: self.__dict__["owner"] = weakref.ref(value) @abstractmethod - async def start( - self, is_root: bool, **extra_options: typing.Dict[str, typing.Any] - ) -> None: ... + async def start(self, is_root: bool, **extra_options: dict[str, Any]) -> None: ... @abstractmethod async def commit(self) -> None: ... @@ -78,25 +76,25 @@ async def rollback(self) -> None: ... @abstractmethod def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Dict[str, typing.Any] - ) -> typing.Optional[str]: ... + self, is_root: bool, **extra_options: dict[str, Any] + ) -> str | None: ... @property - def database(self) -> typing.Optional[DatabaseBackend]: + def database(self) -> DatabaseBackend | None: conn = self.connection if conn is None: return None return conn.database @property - def engine(self) -> typing.Optional[AsyncEngine]: + def engine(self) -> AsyncEngine | None: database = self.database if database is None: return None return database.engine @property - def root(self) -> typing.Optional[RootDatabase]: + def root(self) -> RootDatabase | None: database = self.database if database is None: return None @@ -104,70 +102,68 @@ def root(self) -> typing.Optional[RootDatabase]: class ConnectionBackend(ABC): - async_connection: typing.Optional[typing.Any] = None + async_connection: Any | None = None def __init__(self, database: DatabaseBackend): self.database = database @property - def database(self) -> typing.Optional[DatabaseBackend]: + def database(self) -> DatabaseBackend | None: result = self.__dict__.get("database") if result is None: return None - return typing.cast(DatabaseBackend, result()) + return cast(DatabaseBackend, result()) @database.setter def database(self, value: DatabaseBackend) -> None: self.__dict__["database"] = weakref.ref(value) @property - def owner(self) -> typing.Optional[RootConnection]: + def owner(self) -> RootConnection | None: result = self.__dict__.get("owner") if result is None: return None - return typing.cast("RootConnection", result()) + return cast("RootConnection", result()) @owner.setter def owner(self, value: RootConnection) -> None: self.__dict__["owner"] = weakref.ref(value) @abstractmethod - async def get_raw_connection(self) -> typing.Any: + async def get_raw_connection(self) -> Any: """ Get underlying connection of async_connection. In sqlalchemy based drivers async_connection is the sqlalchemy handle. """ @abstractmethod - async def acquire(self) -> typing.Optional[typing.Any]: ... + async def acquire(self) -> Any | None: ... @abstractmethod async def release(self) -> None: ... @abstractmethod - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: ... + async def fetch_all(self, query: ClauseElement) -> list[Record]: ... @abstractmethod async def batched_iterate( - self, query: ClauseElement, batch_size: typing.Optional[int] = None - ) -> AsyncGenerator[typing.Sequence[Record], None]: + self, query: ClauseElement, batch_size: int | None = None + ) -> AsyncGenerator[Sequence[Record], None]: # mypy needs async iterators to contain a `yield` # https://github.com/python/mypy/issues/5385#issuecomment-407281656 yield True # type: ignore async def iterate( - self, query: ClauseElement, batch_size: typing.Optional[int] = None + self, query: ClauseElement, batch_size: int | None = None ) -> AsyncGenerator[Record, None]: async for batch in self.batched_iterate(query, batch_size): for record in batch: yield record @abstractmethod - async def fetch_one(self, query: ClauseElement, pos: int = 0) -> typing.Optional[Record]: ... + async def fetch_one(self, query: ClauseElement, pos: int = 0) -> Record | None: ... - async def fetch_val( - self, query: ClauseElement, column: typing.Any = 0, pos: int = 0 - ) -> typing.Any: + async def fetch_val(self, query: ClauseElement, column: Any = 0, pos: int = 0) -> Any: row = await self.fetch_one(query, pos=pos) if row is None: return None @@ -178,26 +174,22 @@ async def fetch_val( @abstractmethod async def run_sync( self, - fn: typing.Callable[..., typing.Any], - *args: typing.Any, - **kwargs: typing.Any, - ) -> typing.Any: ... + fn: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: ... @abstractmethod - async def execute_raw(self, stmt: typing.Any, value: typing.Any = None) -> typing.Any: ... + async def execute_raw(self, stmt: Any, value: Any = None) -> Any: ... @abstractmethod - async def execute( - self, stmt: typing.Any, value: typing.Any = None - ) -> typing.Union[Record, int]: + async def execute(self, stmt: Any, value: Any = None) -> Record | int: """ Executes statement and returns the last row defaults (insert) or rowid (insert) or the row count of updates. """ @abstractmethod - async def execute_many( - self, stmt: typing.Any, value: typing.Any = None - ) -> typing.Union[typing.Sequence[Record], int]: + async def execute_many(self, stmt: Any, value: Any = None) -> Sequence[Record] | int: """ Executes statement and returns the row defaults (insert) or the row count of operations. """ @@ -206,15 +198,13 @@ async def execute_many( def in_transaction(self) -> bool: """Is a transaction active?""" - def transaction( - self, existing_transaction: typing.Optional[typing.Any] = None - ) -> TransactionBackend: + def transaction(self, existing_transaction: Any | None = None) -> TransactionBackend: database = self.database assert database is not None return database.transaction_class(self, existing_transaction) @property - def engine(self) -> typing.Optional[AsyncEngine]: + def engine(self) -> AsyncEngine | None: database = self.database if database is None: return None @@ -222,15 +212,15 @@ def engine(self) -> typing.Optional[AsyncEngine]: class DatabaseBackend(ABC): - engine: typing.Optional[AsyncEngine] = None - connection_class: typing.Type[ConnectionBackend] - transaction_class: typing.Type[TransactionBackend] + engine: AsyncEngine | None = None + connection_class: type[ConnectionBackend] + transaction_class: type[TransactionBackend] default_batch_size: int def __init__( self, - connection_class: typing.Type[ConnectionBackend], - transaction_class: typing.Type[TransactionBackend], + connection_class: type[ConnectionBackend], + transaction_class: type[TransactionBackend], ): self.connection_class = connection_class self.transaction_class = transaction_class @@ -239,18 +229,18 @@ def __copy__(self) -> DatabaseBackend: return self.__class__(self.connection_class, self.transaction_class) @property - def owner(self) -> typing.Optional[RootDatabase]: + def owner(self) -> RootDatabase | None: result = self.__dict__.get("owner") if result is None: return None - return typing.cast("RootDatabase", result()) + return cast("RootDatabase", result()) @owner.setter def owner(self, value: RootDatabase) -> None: self.__dict__["owner"] = weakref.ref(value) @abstractmethod - async def connect(self, database_url: DatabaseURL, **options: typing.Any) -> None: + async def connect(self, database_url: DatabaseURL, **options: Any) -> None: """ Set root and start the database backend. @@ -267,8 +257,8 @@ async def disconnect(self) -> None: def extract_options( self, database_url: DatabaseURL, - **options: typing.Dict[str, typing.Any], - ) -> typing.Tuple[DatabaseURL, typing.Dict[str, typing.Any]]: + **options: dict[str, Any], + ) -> tuple[DatabaseURL, dict[str, Any]]: """ Extract options from query. diff --git a/databasez/overwrites/dbapi2.py b/databasez/overwrites/dbapi2.py index 5361315..41ce498 100644 --- a/databasez/overwrites/dbapi2.py +++ b/databasez/overwrites/dbapi2.py @@ -1,15 +1,17 @@ -import typing +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core.databaseurl import DatabaseURL class Transaction(SQLAlchemyTransaction): def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> typing.Optional[str]: + self, is_root: bool, **extra_options: Any + ) -> str | None: return None @@ -18,9 +20,9 @@ class Database(SQLAlchemyDatabase): def extract_options( self, - database_url: "DatabaseURL", - **options: typing.Any, - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + database_url: DatabaseURL, + **options: Any, + ) -> tuple[DatabaseURL, dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) new_query_options = dict(database_url.options) if database_url_new.driver: diff --git a/databasez/overwrites/jdbc.py b/databasez/overwrites/jdbc.py index 394680c..6719546 100644 --- a/databasez/overwrites/jdbc.py +++ b/databasez/overwrites/jdbc.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import os -import typing from pathlib import Path +from typing import TYPE_CHECKING, Any # ensure jpype.dbapi2 is initialized. Prevent race condition. import jpype.dbapi2 # noqa @@ -8,17 +10,17 @@ from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core.databaseurl import DatabaseURL -seen_classpathes: typing.Set[str] = set() +seen_classpathes: set[str] = set() class Transaction(SQLAlchemyTransaction): def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> typing.Optional[str]: + self, is_root: bool, **extra_options: Any + ) -> str | None: return None @@ -27,16 +29,16 @@ class Database(SQLAlchemyDatabase): def extract_options( self, - database_url: "DatabaseURL", - **options: typing.Any, - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + database_url: DatabaseURL, + **options: Any, + ) -> tuple[DatabaseURL, dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) new_query_options = dict(database_url.options) if database_url_new.driver: new_query_options["jdbc_dsn_driver"] = database_url_new.driver if "classpath" in new_query_options: old_classpath = options.pop("classpath", None) - new_classpath: typing.List[str] = [] + new_classpath: list[str] = [] if old_classpath: if isinstance(old_classpath, str): new_classpath.append(old_classpath) @@ -57,10 +59,8 @@ def extract_options( new_query_options["jdbc_driver_args"] = self.json_serializer(jdbc_driver_args) return database_url_new.replace(driver=None, options=new_query_options), options - async def connect(self, database_url: "DatabaseURL", **options: typing.Any) -> None: - classpath: typing.Optional[typing.Union[str, typing.List[str]]] = options.pop( - "classpath", None - ) + async def connect(self, database_url: DatabaseURL, **options: Any) -> None: + classpath: str | list[str] | None = options.pop("classpath", None) if classpath: if isinstance(classpath, str): classpath = [classpath] diff --git a/databasez/overwrites/mssql.py b/databasez/overwrites/mssql.py index 4b657ca..6b6f375 100644 --- a/databasez/overwrites/mssql.py +++ b/databasez/overwrites/mssql.py @@ -1,24 +1,24 @@ -import typing +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core.databaseurl import DatabaseURL class Transaction(SQLAlchemyTransaction): - def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> str: + def get_default_transaction_isolation_level(self, is_root: bool, **extra_options: Any) -> str: return "READ UNCOMMITTED" class Database(SQLAlchemyDatabase): def extract_options( self, - database_url: "DatabaseURL", - **options: typing.Any, - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + database_url: DatabaseURL, + **options: Any, + ) -> tuple[DatabaseURL, dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) if database_url_new.driver in {None, "pyodbc"}: database_url_new = database_url_new.replace(driver="aioodbc") diff --git a/databasez/overwrites/mysql.py b/databasez/overwrites/mysql.py index d2c0be3..ad892bc 100644 --- a/databasez/overwrites/mysql.py +++ b/databasez/overwrites/mysql.py @@ -1,24 +1,24 @@ -import typing +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core.databaseurl import DatabaseURL class Transaction(SQLAlchemyTransaction): - def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> str: + def get_default_transaction_isolation_level(self, is_root: bool, **extra_options: Any) -> str: return "READ COMMITTED" class Database(SQLAlchemyDatabase): def extract_options( self, - database_url: "DatabaseURL", - **options: typing.Any, - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + database_url: DatabaseURL, + **options: Any, + ) -> tuple[DatabaseURL, dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) if database_url_new.driver in {None, "pymysql"}: database_url_new = database_url_new.replace(driver="asyncmy") diff --git a/databasez/overwrites/postgresql.py b/databasez/overwrites/postgresql.py index bdf5abe..e008c9f 100644 --- a/databasez/overwrites/postgresql.py +++ b/databasez/overwrites/postgresql.py @@ -1,8 +1,11 @@ -import typing +from __future__ import annotations + +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from databasez.sqlalchemy import SQLAlchemyConnection, SQLAlchemyDatabase -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from sqlalchemy.sql import ClauseElement from databasez.core.databaseurl import DatabaseURL @@ -11,9 +14,9 @@ class Database(SQLAlchemyDatabase): def extract_options( self, - database_url: "DatabaseURL", - **options: typing.Dict[str, typing.Any], - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + database_url: DatabaseURL, + **options: dict[str, Any], + ) -> tuple[DatabaseURL, dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) if database_url_new.driver in {None, "pscopg2"}: database_url_new = database_url_new.replace(driver="psycopg") @@ -22,8 +25,8 @@ def extract_options( class Connection(SQLAlchemyConnection): async def batched_iterate( - self, query: "ClauseElement", batch_size: typing.Optional[int] = None - ) -> typing.AsyncGenerator[typing.Any, None]: + self, query: ClauseElement, batch_size: int | None = None + ) -> AsyncGenerator[Any, None]: # postgres needs a transaction for iterate/batched_iterate if self.in_transaction(): owner = self.owner @@ -38,8 +41,8 @@ async def batched_iterate( yield batch async def iterate( - self, query: "ClauseElement", batch_size: typing.Optional[int] = None - ) -> typing.AsyncGenerator[typing.Any, None]: + self, query: ClauseElement, batch_size: int | None = None + ) -> AsyncGenerator[Any, None]: # postgres needs a transaction for iterate if self.in_transaction(): owner = self.owner diff --git a/databasez/overwrites/sqlite.py b/databasez/overwrites/sqlite.py index fee6563..d502759 100644 --- a/databasez/overwrites/sqlite.py +++ b/databasez/overwrites/sqlite.py @@ -1,15 +1,13 @@ -import typing +from typing import TYPE_CHECKING, Any from databasez.sqlalchemy import SQLAlchemyDatabase, SQLAlchemyTransaction -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core.databaseurl import DatabaseURL class Transaction(SQLAlchemyTransaction): - def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> str: + def get_default_transaction_isolation_level(self, is_root: bool, **extra_options: Any) -> str: return "READ UNCOMMITTED" @@ -17,8 +15,8 @@ class Database(SQLAlchemyDatabase): def extract_options( self, database_url: "DatabaseURL", - **options: typing.Any, - ) -> typing.Tuple["DatabaseURL", typing.Dict[str, typing.Any]]: + **options: Any, + ) -> tuple["DatabaseURL", dict[str, Any]]: database_url_new, options = super().extract_options(database_url, **options) if database_url_new.driver is None: database_url_new = database_url_new.replace(driver="aiosqlite") diff --git a/databasez/sqlalchemy.py b/databasez/sqlalchemy.py index 75884ed..afd0674 100644 --- a/databasez/sqlalchemy.py +++ b/databasez/sqlalchemy.py @@ -7,8 +7,9 @@ from __future__ import annotations import logging -import typing +from collections.abc import AsyncGenerator, Callable, Iterable, Sequence from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, cast import orjson from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine @@ -16,13 +17,13 @@ from databasez.interfaces import ConnectionBackend, DatabaseBackend, Record, TransactionBackend -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from databasez.core import DatabaseURL logger = logging.getLogger("databasez") -def batched(iterable: typing.Iterable[typing.Any], n: int) -> typing.Any: +def batched(iterable: Iterable[Any], n: int) -> Any: # dropin, batched is not available for pythpn < 3.12 iterator = iter(iterable) batch = tuple(islice(iterator, n)) @@ -32,10 +33,10 @@ def batched(iterable: typing.Iterable[typing.Any], n: int) -> typing.Any: class SQLAlchemyTransaction(TransactionBackend): - raw_transaction: typing.Optional[typing.Any] = None + raw_transaction: Any | None = None old_transaction_level: str = "" - async def start(self, is_root: bool, **extra_options: typing.Any) -> None: + async def start(self, is_root: bool, **extra_options: Any) -> None: connection = self.async_connection assert connection is not None, "Connection is not acquired" assert self.raw_transaction is None, "Transaction is already initialized" @@ -50,7 +51,7 @@ async def start(self, is_root: bool, **extra_options: typing.Any) -> None: if ( extra_options.get("isolation_level") is None - or typing.cast(str, extra_options["isolation_level"]) == self.old_transaction_level + or cast(str, extra_options["isolation_level"]) == self.old_transaction_level ): extra_options.pop("isolation_level", None) self.old_transaction_level = "" @@ -82,15 +83,15 @@ async def rollback(self) -> None: await self._close() def get_default_transaction_isolation_level( - self, is_root: bool, **extra_options: typing.Any - ) -> typing.Optional[str]: + self, is_root: bool, **extra_options: Any + ) -> str | None: return "SERIALIZABLE" class SQLAlchemyConnection(ConnectionBackend): - async_connection: typing.Optional[AsyncConnection] = None + async_connection: AsyncConnection | None = None - async def acquire(self) -> typing.Optional[typing.Any]: + async def acquire(self) -> Any | None: assert self.engine is not None, "Database is not started" assert self.async_connection is None, "Connection is already acquired" self.async_connection = await self.engine.connect() @@ -101,28 +102,28 @@ async def release(self) -> None: connection, self.async_connection = self.async_connection, None await connection.close() - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> list[Record]: with await self.execute_raw(query) as result: - return typing.cast(typing.List[Record], result.fetchall()) + return cast(list[Record], result.fetchall()) - async def fetch_one(self, query: ClauseElement, pos: int = 0) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement, pos: int = 0) -> Record | None: if pos > 0: query = query.offset(pos) if pos >= 0 and hasattr(query, "limit"): query = query.limit(1) with await self.execute_raw(query) as result: if pos >= 0: - return typing.cast(typing.Optional[Record], result.first()) + return cast(Optional[Record], result.first()) elif pos == -1: - return typing.cast(typing.Optional[Record], result.last()) + return cast(Optional[Record], result.last()) else: raise NotImplementedError( f"Only positive numbers and -1 for the last result are currently supported: {pos}" ) async def batched_iterate( - self, query: ClauseElement, batch_size: typing.Optional[int] = None - ) -> typing.AsyncGenerator[typing.Any, None]: + self, query: ClauseElement, batch_size: int | None = None + ) -> AsyncGenerator[Any, None]: connection = self.async_connection assert connection is not None, "Connection is not acquired" database = self.database @@ -132,9 +133,7 @@ async def batched_iterate( if not connection.dialect.supports_server_side_cursors: with await self.execute_raw(query) as result: - for batch in batched( - typing.cast(typing.List[Record], result.fetchall()), batch_size - ): + for batch in batched(cast(list[Record], result.fetchall()), batch_size): yield batch return @@ -150,8 +149,8 @@ async def batched_iterate( await connection.execution_options(yield_per=0) async def iterate( - self, query: ClauseElement, batch_size: typing.Optional[int] = None - ) -> typing.AsyncGenerator[typing.Any, None]: + self, query: ClauseElement, batch_size: int | None = None + ) -> AsyncGenerator[Any, None]: connection = self.async_connection assert connection is not None, "Connection is not acquired" database = self.database @@ -175,29 +174,27 @@ async def iterate( # undo the connection change await connection.execution_options(yield_per=0) - async def execute_raw(self, stmt: typing.Any, value: typing.Any = None) -> typing.Any: + async def execute_raw(self, stmt: Any, value: Any = None) -> Any: connection = self.async_connection assert connection is not None, "Connection is not acquired" if value is not None: return await connection.execute(stmt, value) return await connection.execute(stmt) - def parse_execute_result(self, result: typing.Any) -> typing.Union[Record, int]: + def parse_execute_result(self, result: Any) -> Record | int: if result.is_insert: try: if result.inserted_primary_key: - return typing.cast(Record, result.inserted_primary_key) + return cast(Record, result.inserted_primary_key) except AttributeError: pass try: - return typing.cast(int, result.lastrowid) + return cast(int, result.lastrowid) except AttributeError: pass - return typing.cast(int, result.rowcount) + return cast(int, result.rowcount) - async def execute( - self, stmt: typing.Any, value: typing.Any = None - ) -> typing.Union[Record, int]: + async def execute(self, stmt: Any, value: Any = None) -> Record | int: """ Executes statement and returns the last row defaults (insert) or rowid (insert) or the row count of updates. """ @@ -205,25 +202,23 @@ async def execute( with await self.execute_raw(stmt, value) as result: return self.parse_execute_result(result) - def parse_execute_many_result( - self, result: typing.Any - ) -> typing.Union[typing.Sequence[Record], int]: + def parse_execute_many_result(self, result: Any) -> Sequence[Record] | int: if result.is_insert: try: if result.inserted_primary_key_rows is not None: # WARNING: only postgresql, other dbs have None values - return typing.cast(typing.Sequence[Record], result.inserted_primary_key_rows) + return cast(Sequence[Record], result.inserted_primary_key_rows) except AttributeError: pass - return typing.cast(int, result.rowcount) + return cast(int, result.rowcount) async def execute_many( - self, stmt: typing.Union[ClauseElement, str], values: typing.Any = None - ) -> typing.Union[typing.Sequence[Record], int]: + self, stmt: ClauseElement | str, values: Any = None + ) -> Sequence[Record] | int: with await self.execute_raw(stmt, values) as result: return self.parse_execute_many_result(result) - async def get_raw_connection(self) -> typing.Any: + async def get_raw_connection(self) -> Any: """The real raw connection.""" connection = self.async_connection assert connection is not None, "Connection is not acquired" @@ -231,10 +226,10 @@ async def get_raw_connection(self) -> typing.Any: async def run_sync( self, - fn: typing.Callable[..., typing.Any], - *args: typing.Any, - **kwargs: typing.Any, - ) -> typing.Any: + fn: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: connection = self.async_connection assert connection is not None, "Connection is not acquired" return await connection.run_sync(fn, *args, **kwargs) @@ -246,7 +241,7 @@ def in_transaction(self) -> bool: class SQLAlchemyDatabase(DatabaseBackend): - default_isolation_level: typing.Optional[str] = "AUTOCOMMIT" + default_isolation_level: str | None = "AUTOCOMMIT" default_batch_size: int = 100 def __copy__(self) -> DatabaseBackend: @@ -258,28 +253,26 @@ def __copy__(self) -> DatabaseBackend: def extract_options( self, database_url: DatabaseURL, - **options: typing.Any, - ) -> typing.Tuple[DatabaseURL, typing.Dict[str, typing.Any]]: + **options: Any, + ) -> tuple[DatabaseURL, dict[str, Any]]: # we have our own logic options.setdefault("pool_reset_on_return", None) new_query_options = dict(database_url.options) for param in ["ssl", "echo", "echo_pool"]: if param in new_query_options: assert param not in options - value = typing.cast(str, new_query_options.pop(param)) + value = cast(str, new_query_options.pop(param)) options[param] = value.lower() in {"true", ""} if "isolation_level" in new_query_options: assert "isolation_level" not in options - options["isolation_level"] = typing.cast(str, new_query_options.pop(param)) + options["isolation_level"] = cast(str, new_query_options.pop(param)) for param in ["pool_size", "max_overflow"]: if param in new_query_options: assert param not in options - options[param] = int(typing.cast(str, new_query_options.pop(param))) + options[param] = int(cast(str, new_query_options.pop(param))) if "pool_recycle" in new_query_options: assert "pool_recycle" not in options - options["pool_recycle"] = float( - typing.cast(str, new_query_options.pop("pool_recycle")) - ) + options["pool_recycle"] = float(cast(str, new_query_options.pop("pool_recycle"))) if self.default_isolation_level is not None: options.setdefault("isolation_level", self.default_isolation_level) return database_url.replace(options=new_query_options), options @@ -287,10 +280,10 @@ def extract_options( def json_serializer(self, inp: dict) -> str: return orjson.dumps(inp).decode("utf8") - def json_deserializer(self, inp: typing.Union[str, bytes]) -> dict: - return typing.cast(dict, orjson.loads(inp)) + def json_deserializer(self, inp: str | bytes) -> dict: + return cast(dict, orjson.loads(inp)) - async def connect(self, database_url: DatabaseURL, **options: typing.Any) -> None: + async def connect(self, database_url: DatabaseURL, **options: Any) -> None: self.engine = create_async_engine(database_url.sqla_url, **options) async def disconnect(self) -> None: diff --git a/databasez/testclient.py b/databasez/testclient.py index ebdeb0f..6047d57 100644 --- a/databasez/testclient.py +++ b/databasez/testclient.py @@ -1,7 +1,7 @@ import asyncio +import contextlib import os -import typing -from typing import Any +from typing import Any, Union import sqlalchemy from sqlalchemy.exc import OperationalError, ProgrammingError @@ -12,7 +12,7 @@ from databasez.utils import DATABASEZ_POLL_INTERVAL, ThreadPassingExceptions -async def _get_scalar_result(engine: typing.Any, sql: typing.Any) -> Any: +async def _get_scalar_result(engine: Any, sql: Any) -> Any: try: async with engine.connect() as conn: return await conn.scalar(sql) @@ -45,16 +45,16 @@ class DatabaseTestClient(Database): def __init__( self, - url: typing.Union[str, DatabaseURL, sqlalchemy.URL, Database, None] = None, + url: Union[str, DatabaseURL, sqlalchemy.URL, Database, None] = None, *, - force_rollback: typing.Union[bool, None] = None, - full_isolation: typing.Union[bool, None] = None, - poll_interval: typing.Union[float, None] = None, - use_existing: typing.Union[bool, None] = None, - drop_database: typing.Union[bool, None] = None, - lazy_setup: typing.Union[bool, None] = None, - test_prefix: typing.Union[str, None] = None, - **options: typing.Any, + force_rollback: Union[bool, None] = None, + full_isolation: Union[bool, None] = None, + poll_interval: Union[float, None] = None, + use_existing: Union[bool, None] = None, + drop_database: Union[bool, None] = None, + lazy_setup: Union[bool, None] = None, + test_prefix: Union[str, None] = None, + **options: Any, ): if use_existing is None: use_existing = self.testclient_default_use_existing @@ -126,10 +126,8 @@ async def setup(self) -> None: def setup_protected(self, operation_timeout: float) -> None: thread = ThreadPassingExceptions(target=asyncio.run, args=[self.setup()]) thread.start() - try: + with contextlib.suppress(TimeoutError): thread.join(operation_timeout) - except TimeoutError: - pass async def connect_hook(self) -> None: if not self._setup_executed_init: @@ -143,12 +141,12 @@ async def is_database_exist(self) -> Any: return await self.database_exists(self.test_db_url) @classmethod - async def database_exists(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> bool: + async def database_exists(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -> bool: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database dialect_name = url.sqla_url.get_dialect(True).name if dialect_name == "postgresql": - text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database + text = f"SELECT 1 FROM pg_database WHERE datname='{database}'" for db in (database, "postgres", "template1", "template0", None): url = url.replace(database=db) async with Database(url, full_isolation=False, force_rollback=False) as db_client: @@ -164,7 +162,7 @@ async def database_exists(cls, url: typing.Union[str, "sqlalchemy.URL", Database url = url.replace(database=None) text = ( "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " - "WHERE SCHEMA_NAME = '%s'" % database + f"WHERE SCHEMA_NAME = '{database}'" ) async with Database(url, full_isolation=False, force_rollback=False) as db_client: return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text))) @@ -187,9 +185,9 @@ async def database_exists(cls, url: typing.Union[str, "sqlalchemy.URL", Database @classmethod async def create_database( cls, - url: typing.Union[str, "sqlalchemy.URL", DatabaseURL], + url: Union[str, "sqlalchemy.URL", DatabaseURL], encoding: str = "utf8", - template: typing.Any = None, + template: Any = None, ) -> None: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database @@ -221,16 +219,15 @@ async def create_database( template = "template1" async with db_client.engine.begin() as conn: # type: ignore - text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format( - quote(conn, database), encoding, quote(conn, template) + text = ( + f"CREATE DATABASE {quote(conn, database)} ENCODING " + f"'{encoding}' TEMPLATE {quote(conn, template)}" ) await conn.execute(sqlalchemy.text(text)) elif dialect_name == "mysql": async with db_client.engine.begin() as conn: # type: ignore - text = "CREATE DATABASE {} CHARACTER SET = '{}'".format( - quote(conn, database), encoding - ) + text = f"CREATE DATABASE {quote(conn, database)} CHARACTER SET = '{encoding}'" await conn.execute(sqlalchemy.text(text)) elif dialect_name == "sqlite" and database != ":memory:": @@ -246,7 +243,7 @@ async def create_database( await conn.execute(sqlalchemy.text(text)) @classmethod - async def drop_database(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> None: + async def drop_database(cls, url: Union[str, "sqlalchemy.URL", DatabaseURL]) -> None: url = url if isinstance(url, DatabaseURL) else DatabaseURL(url) database = url.database dialect = url.sqla_url.get_dialect(True) @@ -273,10 +270,8 @@ async def drop_database(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseUR db_client = Database(url, force_rollback=False, full_isolation=False) async with db_client: if dialect_name == "sqlite" and database and database != ":memory:": - try: + with contextlib.suppress(FileNotFoundError): os.remove(database) - except FileNotFoundError: - pass elif dialect_name.startswith("postgres"): async with db_client.connection() as conn: # Disconnect all users from the database we are dropping. @@ -288,20 +283,18 @@ async def drop_database(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseUR version = tuple(map(int, server_version_raw.split("."))) pid_column = "pid" if (version >= (9, 2)) else "procpid" quoted_db = quote(conn.async_connection, database) - text = """ + text = f""" SELECT pg_terminate_backend(pg_stat_activity.{pid_column}) FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{database}' + WHERE pg_stat_activity.datname = '{quoted_db}' AND {pid_column} <> pg_backend_pid(); - """.format(pid_column=pid_column, database=quoted_db) + """ await conn.execute(text) # Drop the database. text = f"DROP DATABASE {quoted_db}" - try: + with contextlib.suppress(ProgrammingError): await conn.execute(text) - except ProgrammingError: - pass else: async with db_client.connection() as conn: text = f"DROP DATABASE {quote(conn.async_connection, database)}" @@ -312,10 +305,8 @@ def drop_db_protected(self) -> None: target=asyncio.run, args=[self.drop_database(self.test_db_url)] ) thread.start() - try: + with contextlib.suppress(TimeoutError): thread.join(self.testclient_operation_timeout) - except TimeoutError: - pass async def disconnect_hook(self) -> None: # next connect the setup routine is reexecuted diff --git a/databasez/types.py b/databasez/types.py index 3fa9c9b..da71cc1 100644 --- a/databasez/types.py +++ b/databasez/types.py @@ -1,10 +1,10 @@ -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Callable, Dict, TypeVar +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from databasez import interfaces -DictAny = Dict[str, Any] +DictAny = dict[str, Any] BatchCallableResult = TypeVar("BatchCallableResult") BatchCallable = Callable[[Sequence["interfaces.Record"]], BatchCallableResult] diff --git a/databasez/utils.py b/databasez/utils.py index 4c432ec..4ad95b8 100644 --- a/databasez/utils.py +++ b/databasez/utils.py @@ -1,18 +1,21 @@ +from __future__ import annotations + import asyncio import inspect -import typing +from collections.abc import Callable, Coroutine from concurrent.futures import Future from functools import partial, wraps from threading import Thread +from typing import Any, TypeVar, cast -DATABASEZ_RESULT_TIMEOUT: typing.Optional[float] = None +DATABASEZ_RESULT_TIMEOUT: float | None = None # Poll with 0.1ms, this way CPU isn't at 100% DATABASEZ_POLL_INTERVAL: float = 0.0001 async def _arun_coroutine_threadsafe_result_shim( future: Future, loop: asyncio.AbstractEventLoop, poll_interval: float -) -> typing.Any: +) -> Any: while not future.done(): if loop.is_closed(): raise RuntimeError("loop submitted to is closed") @@ -21,8 +24,8 @@ async def _arun_coroutine_threadsafe_result_shim( async def arun_coroutine_threadsafe( - coro: typing.Coroutine, loop: typing.Optional[asyncio.AbstractEventLoop], poll_interval: float -) -> typing.Any: + coro: Coroutine, loop: asyncio.AbstractEventLoop | None, poll_interval: float +) -> Any: running_loop = asyncio.get_running_loop() assert loop is not None and loop.is_running(), "loop is closed" if running_loop is loop: @@ -52,18 +55,18 @@ async def arun_coroutine_threadsafe( class AsyncWrapper: __slots__ = async_wrapper_slots - _async_wrapped: typing.Any - _async_pool: typing.Any - _async_exclude_attrs: typing.Dict[str, typing.Any] - _async_exclude_types: typing.Tuple[typing.Type[typing.Any], ...] + _async_wrapped: Any + _async_pool: Any + _async_exclude_attrs: dict[str, Any] + _async_exclude_types: tuple[type[Any], ...] _async_stringify_exceptions: bool def __init__( self, - wrapped: typing.Any, - pool: typing.Any, - exclude_attrs: typing.Optional[typing.Dict[str, typing.Any]] = None, - exclude_types: typing.Tuple[typing.Type[typing.Any], ...] = default_exclude_types, + wrapped: Any, + pool: Any, + exclude_attrs: dict[str, Any] | None = None, + exclude_types: tuple[type[Any], ...] = default_exclude_types, stringify_exceptions: bool = False, ) -> None: self._async_wrapped = wrapped @@ -72,7 +75,7 @@ def __init__( self._async_exclude_types = exclude_types self._async_stringify_exceptions = stringify_exceptions - def __getattribute__(self, name: str) -> typing.Any: + def __getattribute__(self, name: str) -> Any: if name in async_wrapper_slots: return super().__getattribute__(name) if name == "__aenter__": @@ -88,7 +91,7 @@ def __getattribute__(self, name: str) -> typing.Any: except AttributeError: if name == "__enter__": - async def fn() -> typing.Any: + async def fn() -> Any: return self return fn @@ -103,7 +106,7 @@ async def fn() -> typing.Any: if self._async_exclude_attrs.get(name) is True: if inspect.isroutine(attr): # submit to threadpool - def fn2(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + def fn2(*args: Any, **kwargs: Any) -> Any: try: return AsyncWrapper( self._async_pool.submit(partial(attr, *args, **kwargs)).result(), @@ -130,7 +133,7 @@ def fn2(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: if isinstance(attr, type) and issubclass(attr, self._async_exclude_types): return attr - async def fn3(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + async def fn3(*args: Any, **kwargs: Any) -> Any: loop = asyncio.get_running_loop() try: result = await loop.run_in_executor( @@ -158,7 +161,7 @@ async def fn3(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: class ThreadPassingExceptions(Thread): - _exc_raised: typing.Any = None + _exc_raised: Any = None def run(self) -> None: try: @@ -166,22 +169,22 @@ def run(self) -> None: except Exception as exc: self._exc_raised = exc - def join(self, timeout: typing.Union[float, int, None] = None) -> None: + def join(self, timeout: float | int | None = None) -> None: super().join(timeout=timeout) if self._exc_raised: raise self._exc_raised -MultiloopProtectorCallable = typing.TypeVar("MultiloopProtectorCallable", bound=typing.Callable) +MultiloopProtectorCallable = TypeVar("MultiloopProtectorCallable", bound=Callable) -def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any: +def _run_with_timeout(inp: Any, timeout: float | None) -> Any: if timeout is not None and timeout > 0 and inspect.isawaitable(inp): inp = asyncio.wait_for(inp, timeout=timeout) return inp -async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any: +async def _arun_with_timeout(inp: Any, timeout: float | None) -> Any: if timeout is not None and timeout > 0 and inspect.isawaitable(inp): return await asyncio.wait_for(inp, timeout=timeout) elif inspect.isawaitable(inp): @@ -191,7 +194,7 @@ async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) - def multiloop_protector( fail_with_different_loop: bool, inject_parent: bool = False, passthrough_timeout: bool = False -) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: +) -> Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]: """For multiple threads or other reasons why the loop changes""" # True works with all methods False only for methods of Database @@ -199,11 +202,11 @@ def multiloop_protector( def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable: @wraps(fn) def wrapper( - self: typing.Any, - *args: typing.Any, - **kwargs: typing.Any, - ) -> typing.Any: - timeout: typing.Optional[float] = None + self: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + timeout: float | None = None if not passthrough_timeout and "timeout" in kwargs: timeout = kwargs.pop("timeout") if inject_parent: @@ -226,6 +229,6 @@ def wrapper( return self.async_helper(self, fn, args, kwargs, timeout=timeout) return _run_with_timeout(fn(self, *args, **kwargs), timeout=timeout) - return typing.cast(MultiloopProtectorCallable, wrapper) + return cast(MultiloopProtectorCallable, wrapper) return _decorator diff --git a/docs/index.md b/docs/index.md index 1315573..2e18d90 100644 --- a/docs/index.md +++ b/docs/index.md @@ -56,7 +56,7 @@ Databasez is suitable for integrating against any async Web framework, such as [ [Starlette][starlette], [Sanic][sanic], [Responder][responder], [Quart][quart], [aiohttp][aiohttp], [Tornado][tornado], or [FastAPI][fastapi]. -Databasez was built for Python 3.8+ and on the top of the newest **SQLAlchemy 2** and gives you +Databasez was built for Python 3.9+ and on the top of the newest **SQLAlchemy 2** and gives you simple asyncio support for a range of databases. ### Special notes diff --git a/pyproject.toml b/pyproject.toml index a5ab411..410cf0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "databasez" description = "Async database support for Python." long_description = "Async database support for Python." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" dynamic = ['version'] license = "MIT" authors = [{ name = "Tiago Silva", email = "tiago.arasilva@gmail.com" }] @@ -30,7 +30,6 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -104,7 +103,8 @@ select = [ "B", # flake8-bugbear "I", # isort "ASYNC", # ASYNC - # "SIM", # simplification + "UP", + "SIM", # simplification ] ignore = [ diff --git a/tests/shared_db.py b/tests/shared_db.py index a066444..703a855 100644 --- a/tests/shared_db.py +++ b/tests/shared_db.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import datetime -import typing from unittest.mock import MagicMock import sqlalchemy @@ -70,7 +71,7 @@ def process_result_value(self, value, dialect): ) -async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTestClient: +async def database_client(url: dict | str, meta=None) -> DatabaseTestClient: if meta is None: meta = metadata if isinstance(url, str): diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index d79d52d..fc6ada7 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,6 +1,5 @@ import asyncio import contextvars -import functools import os from concurrent.futures import Future from threading import Thread @@ -22,22 +21,11 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] if os.environ.get("TEST_NO_RISK_SEGFAULTS") or not any( - (x.endswith(" for SQL Server") for x in pyodbc.drivers()) + x.endswith(" for SQL Server") for x in pyodbc.drivers() ): DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) -try: - to_thread = asyncio.to_thread -except AttributeError: - # for py <= 3.8 - async def to_thread(func, /, *args, **kwargs): - loop = asyncio.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - @pytest.fixture(params=DATABASE_URLS) def database_url(request): """Yield test database despite its name""" @@ -49,10 +37,7 @@ def database_url(request): def _startswith(tested, params): - for param in params: - if tested.startswith(param): - return True - return False + return any(tested.startswith(param) for param in params) def _future_helper(awaitable, future): @@ -114,7 +99,7 @@ async def wrap_in_thread(): thread.start() future.result(4) else: - await to_thread(asyncio.run, asyncio.wait_for(db_lookup(True), 3)) + await asyncio.to_thread(asyncio.run, asyncio.wait_for(db_lookup(True), 3)) await asyncio.gather(db_lookup(False), wrap_in_thread(), wrap_in_thread()) @@ -212,7 +197,7 @@ async def db_connect(depth=3): ops = [] while depth >= 0: depth -= 1 - ops.append(to_thread(asyncio.run, db_connect(depth=depth))) + ops.append(asyncio.to_thread(asyncio.run, db_connect(depth=depth))) await asyncio.gather(*ops) assert new_database.ref_counter == 0 @@ -226,7 +211,7 @@ async def db_connect(depth=3): thread.start() future.result(4) else: - await to_thread(asyncio.run, asyncio.wait_for(db_connect(), 3)) + await asyncio.to_thread(asyncio.run, asyncio.wait_for(db_connect(), 3)) assert database.ref_counter == 0 if force_rollback: async with database: @@ -244,7 +229,7 @@ async def db_connect(): await database.fetch_one("SELECT 1") await database.disconnect() - await to_thread(asyncio.run, db_connect()) + await asyncio.to_thread(asyncio.run, db_connect()) @pytest.mark.asyncio @@ -256,4 +241,4 @@ async def db_connect(): database.disconnect() with pytest.raises(RuntimeError): - await to_thread(asyncio.run, db_connect()) + await asyncio.to_thread(asyncio.run, db_connect()) diff --git a/tests/test_database_testclient.py b/tests/test_database_testclient.py index 1432abc..98f294c 100644 --- a/tests/test_database_testclient.py +++ b/tests/test_database_testclient.py @@ -10,7 +10,7 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): +if not any(x.endswith(" for SQL Server") for x in pyodbc.drivers()): DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) @@ -93,10 +93,9 @@ async def test_client_drop_existing(database_url): database2 = DatabaseTestClient( database_url, test_prefix="foobar", use_existing=True, drop_database=True, lazy_setup=True ) - async with database2: - async with database2.connection() as conn: - # doesn't crash - await conn.fetch_all("select * from FOOBAR") + async with database2, database2.connection() as conn: + # doesn't crash + await conn.fetch_all("select * from FOOBAR") if database2.drop: assert not await database2.database_exists(database.test_db_url) diff --git a/tests/test_database_url.py b/tests/test_database_url.py index b928023..724e200 100644 --- a/tests/test_database_url.py +++ b/tests/test_database_url.py @@ -44,7 +44,7 @@ def test_database_url_escape(): u = DatabaseURL(f"postgresql://username:{quote('[password')}@localhost/mydatabase") assert u.username == "username" assert u.password == "[password" - assert u.userinfo == f"username:{quote('[password')}".encode("utf-8") + assert u.userinfo == f"username:{quote('[password')}".encode() u2 = DatabaseURL(u) assert u2.password == "[password" diff --git a/tests/test_databases.py b/tests/test_databases.py index b9d48dd..0334ec4 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -27,7 +27,7 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): +if not any(x.endswith(" for SQL Server") for x in pyodbc.drivers()): DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) DATABASE_CONFIG_URLS = [] @@ -56,10 +56,7 @@ def _startswith(tested, params): - for param in params: - if tested.startswith(param): - return True - return False + return any(tested.startswith(param) for param in params) @pytest.fixture(params=DATABASE_URLS) @@ -176,59 +173,58 @@ async def test_queries_raw(database_url): Test that the basic `execute()`, `execute_many()`, `fetch_all()``, and `fetch_one()` interfaces are all supported (raw queries). """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # execute() - query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" - values = {"text": "example1", "completed": True} - await database.execute(query, values) - - # execute_many() - query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" - values = [ - {"text": "example2", "completed": False}, - {"text": "example3", "completed": True}, - ] - await database.execute_many(query, values) - - # fetch_all() - query = "SELECT * FROM notes WHERE completed = :completed" - results = await database.fetch_all(query=query, values={"completed": True}) - assert len(results) == 2 - assert results[0].text == "example1" - assert results[0].completed == True - assert results[1].text == "example3" - assert results[1].completed == True - - # fetch_one() - query = "SELECT * FROM notes WHERE completed = :completed" - result = await database.fetch_one(query=query, values={"completed": False}) - assert result.text == "example2" - assert result.completed == False - - # fetch_val() - query = "SELECT completed FROM notes WHERE text = :text" - result = await database.fetch_val(query=query, values={"text": "example1"}) - assert result == True - - query = "SELECT * FROM notes WHERE text = :text" - result = await database.fetch_val( - query=query, values={"text": "example1"}, column="completed" - ) - assert result == True - - # iterate() - query = "SELECT * FROM notes" - iterate_results = [] - async for result in database.iterate(query=query): - iterate_results.append(result) - assert len(iterate_results) == 3 - assert iterate_results[0].text == "example1" - assert iterate_results[0].completed == True - assert iterate_results[1].text == "example2" - assert iterate_results[1].completed == False - assert iterate_results[2].text == "example3" - assert iterate_results[2].completed == True + async with Database(database_url) as database, database.transaction(force_rollback=True): + # execute() + query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" + values = {"text": "example1", "completed": True} + await database.execute(query, values) + + # execute_many() + query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" + values = [ + {"text": "example2", "completed": False}, + {"text": "example3", "completed": True}, + ] + await database.execute_many(query, values) + + # fetch_all() + query = "SELECT * FROM notes WHERE completed = :completed" + results = await database.fetch_all(query=query, values={"completed": True}) + assert len(results) == 2 + assert results[0].text == "example1" + assert results[0].completed == True + assert results[1].text == "example3" + assert results[1].completed == True + + # fetch_one() + query = "SELECT * FROM notes WHERE completed = :completed" + result = await database.fetch_one(query=query, values={"completed": False}) + assert result.text == "example2" + assert result.completed == False + + # fetch_val() + query = "SELECT completed FROM notes WHERE text = :text" + result = await database.fetch_val(query=query, values={"text": "example1"}) + assert result == True + + query = "SELECT * FROM notes WHERE text = :text" + result = await database.fetch_val( + query=query, values={"text": "example1"}, column="completed" + ) + assert result == True + + # iterate() + query = "SELECT * FROM notes" + iterate_results = [] + async for result in database.iterate(query=query): + iterate_results.append(result) + assert len(iterate_results) == 3 + assert iterate_results[0].text == "example1" + assert iterate_results[0].completed == True + assert iterate_results[1].text == "example2" + assert iterate_results[1].completed == False + assert iterate_results[2].text == "example3" + assert iterate_results[2].completed == True @pytest.mark.asyncio @@ -238,15 +234,14 @@ async def test_ddl_queries(database_url): `CreateTable()` are supported (using SQLAlchemy core). """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # DropTable() - query = sqlalchemy.schema.DropTable(notes) - await database.execute(query) + async with Database(database_url) as database, database.transaction(force_rollback=True): + # DropTable() + query = sqlalchemy.schema.DropTable(notes) + await database.execute(query) - # CreateTable() - query = sqlalchemy.schema.CreateTable(notes) - await database.execute(query) + # CreateTable() + query = sqlalchemy.schema.CreateTable(notes) + await database.execute(query) @pytest.mark.asyncio @@ -306,14 +301,16 @@ async def test_queries_after_error(database_url, exception): """ async with Database(database_url) as database: - with patch.object( - database.connection()._connection, - "acquire", - new=AsyncMock(side_effect=exception), + with ( + patch.object( + database.connection()._connection, + "acquire", + new=AsyncMock(side_effect=exception), + ), + pytest.raises(exception), ): - with pytest.raises(exception): - query = notes.select() - await database.fetch_all(query) + query = notes.select() + await database.fetch_all(query) query = notes.select() await database.fetch_all(query) @@ -325,24 +322,23 @@ async def test_results_support_mapping_interface(database_url): Casting results to a dict should work, since the interface defines them as supporting the mapping interface. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # execute() - query = notes.insert() - values = {"text": "example1", "completed": True} - await database.execute(query, values) + async with Database(database_url) as database, database.transaction(force_rollback=True): + # execute() + query = notes.insert() + values = {"text": "example1", "completed": True} + await database.execute(query, values) - # fetch_all() - query = notes.select() - results = await database.fetch_all(query=query) - results_as_dicts = [dict(item._mapping) for item in results] + # fetch_all() + query = notes.select() + results = await database.fetch_all(query=query) + results_as_dicts = [dict(item._mapping) for item in results] - assert len(results[0]) == 3 - assert len(results_as_dicts[0]) == 3 + assert len(results[0]) == 3 + assert len(results_as_dicts[0]) == 3 - assert isinstance(results_as_dicts[0]["id"], int) - assert results_as_dicts[0]["text"] == "example1" - assert results_as_dicts[0]["completed"] is True + assert isinstance(results_as_dicts[0]["id"], int) + assert results_as_dicts[0]["text"] == "example1" + assert results_as_dicts[0]["completed"] is True @pytest.mark.asyncio @@ -351,13 +347,12 @@ async def test_result_values_allow_duplicate_names(database_url): The values of a result should respect when two columns are selected with the same name. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - query = "SELECT 1 AS id, 2 AS id" - row = await database.fetch_one(query=query) + async with Database(database_url) as database, database.transaction(force_rollback=True): + query = "SELECT 1 AS id, 2 AS id" + row = await database.fetch_one(query=query) - assert list(row._mapping.keys()) == ["id", "id"] - assert list(row._mapping.values()) == [1, 2] + assert list(row._mapping.keys()) == ["id", "id"] + assert list(row._mapping.values()) == [1, 2] @pytest.mark.asyncio @@ -365,12 +360,11 @@ async def test_fetch_one_returning_no_results(database_url): """ fetch_one should return `None` when no results match. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # fetch_all() - query = notes.select() - result = await database.fetch_one(query=query) - assert result is None + async with Database(database_url) as database, database.transaction(force_rollback=True): + # fetch_all() + query = notes.select() + result = await database.fetch_one(query=query) + assert result is None @pytest.mark.asyncio @@ -378,27 +372,26 @@ async def test_execute_return_val(database_url): """ Test using return value from `execute()` to get an inserted primary key. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - query = notes.insert() - values = {"text": "example1", "completed": True} - pk1 = await database.execute(query, values) - if isinstance(pk1, Sequence): - pk1 = pk1[0] - values = {"text": "example2", "completed": True} - pk2 = await database.execute(query, values) - if isinstance(pk2, Sequence): - pk2 = pk2[0] - assert isinstance(pk1, int) and pk1 > 0 - query = notes.select().where(notes.c.id == pk1) - result = await database.fetch_one(query) - assert result.text == "example1" - assert result.completed is True - assert isinstance(pk2, int) and pk2 > 0 - query = notes.select().where(notes.c.id == pk2) - result = await database.fetch_one(query) - assert result.text == "example2" - assert result.completed is True + async with Database(database_url) as database, database.transaction(force_rollback=True): + query = notes.insert() + values = {"text": "example1", "completed": True} + pk1 = await database.execute(query, values) + if isinstance(pk1, Sequence): + pk1 = pk1[0] + values = {"text": "example2", "completed": True} + pk2 = await database.execute(query, values) + if isinstance(pk2, Sequence): + pk2 = pk2[0] + assert isinstance(pk1, int) and pk1 > 0 + query = notes.select().where(notes.c.id == pk1) + result = await database.fetch_one(query) + assert result.text == "example1" + assert result.completed is True + assert isinstance(pk2, int) and pk2 > 0 + query = notes.select().where(notes.c.id == pk2) + result = await database.fetch_one(query) + assert result.text == "example2" + assert result.completed is True @pytest.mark.asyncio @@ -406,21 +399,20 @@ async def test_datetime_field(database_url): """ Test DataTime columns, to ensure records are coerced to/from proper Python types. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - now = datetime.datetime.now().replace(microsecond=0) + async with Database(database_url) as database, database.transaction(force_rollback=True): + now = datetime.datetime.now().replace(microsecond=0) - # execute() - query = articles.insert() - values = {"title": "Hello, world", "published": now} - await database.execute(query, values) + # execute() + query = articles.insert() + values = {"title": "Hello, world", "published": now} + await database.execute(query, values) - # fetch_all() - query = articles.select() - results = await database.fetch_all(query=query) - assert len(results) == 1 - assert results[0].title == "Hello, world" - assert results[0].published == now + # fetch_all() + query = articles.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0].title == "Hello, world" + assert results[0].published == now @pytest.mark.asyncio @@ -428,24 +420,23 @@ async def test_decimal_field(database_url): """ Test Decimal (NUMERIC) columns, to ensure records are coerced to/from proper Python types. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - price = decimal.Decimal("0.700000000000001") + async with Database(database_url) as database, database.transaction(force_rollback=True): + price = decimal.Decimal("0.700000000000001") - # execute() - query = prices.insert() - values = {"price": price} - await database.execute(query, values) + # execute() + query = prices.insert() + values = {"price": price} + await database.execute(query, values) - # fetch_all() - query = prices.select() - results = await database.fetch_all(query=query) - assert len(results) == 1 - if str(database.url).startswith("sqlite"): - # aiosqlite does not support native decimals --> a round-off error is expected - assert results[0].price == pytest.approx(price) - else: - assert results[0].price == price + # fetch_all() + query = prices.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + if str(database.url).startswith("sqlite"): + # aiosqlite does not support native decimals --> a round-off error is expected + assert results[0].price == pytest.approx(price) + else: + assert results[0].price == price @pytest.mark.asyncio @@ -453,20 +444,19 @@ async def test_json_field(database_url): """ Test JSON columns, to ensure correct cross-database support. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # execute() - data = {"text": "hello", "boolean": True, "int": 1} - values = {"data": data} - query = session.insert() - await database.execute(query, values) - - # fetch_all() - query = session.select() - results = await database.fetch_all(query=query) + async with Database(database_url) as database, database.transaction(force_rollback=True): + # execute() + data = {"text": "hello", "boolean": True, "int": 1} + values = {"data": data} + query = session.insert() + await database.execute(query, values) - assert len(results) == 1 - assert results[0].data == {"text": "hello", "boolean": True, "int": 1} + # fetch_all() + query = session.select() + results = await database.fetch_all(query=query) + + assert len(results) == 1 + assert results[0].data == {"text": "hello", "boolean": True, "int": 1} @pytest.mark.asyncio @@ -474,22 +464,21 @@ async def test_custom_field(database_url): """ Test custom column types. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - today = datetime.date.today() + async with Database(database_url) as database, database.transaction(force_rollback=True): + today = datetime.date.today() - # execute() - query = custom_date.insert() - values = {"title": "Hello, world", "published": today} + # execute() + query = custom_date.insert() + values = {"title": "Hello, world", "published": today} - await database.execute(query, values) + await database.execute(query, values) - # fetch_all() - query = custom_date.select() - results = await database.fetch_all(query=query) - assert len(results) == 1 - assert results[0].title == "Hello, world" - assert results[0].published == today + # fetch_all() + query = custom_date.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 + assert results[0].title == "Hello, world" + assert results[0].published == today @pytest.mark.asyncio @@ -517,10 +506,11 @@ async def test_connect_and_disconnect(database_mixed_url): """ Test explicit connect() and disconnect(). """ - if isinstance(database_mixed_url, str): - data = {"url": database_mixed_url} - else: - data = {"config": database_mixed_url} + data = ( + {"url": database_mixed_url} + if isinstance(database_mixed_url, str) + else {"config": database_mixed_url} + ) database = Database(**data) @@ -618,10 +608,9 @@ async def test_connection_context(database_url): """ Test connection contexts are task-local. """ - async with Database(database_url) as database: - async with database.connection() as connection_1: - async with database.connection() as connection_2: - assert connection_1 is connection_2 + async with Database(database_url) as database, database.connection() as connection_1: # noqa: SIM117 + async with database.connection() as connection_2: + assert connection_1 is connection_2 async with Database(database_url) as database: connection_1 = None @@ -658,11 +647,10 @@ async def test_connection_context_with_raw_connection(database_url): """ Test connection contexts with respect to the raw connection. """ - async with Database(database_url) as database: - async with database.connection() as connection_1: - async with database.connection() as connection_2: - assert connection_1 is connection_2 - assert connection_1.async_connection is connection_2.async_connection + async with Database(database_url) as database, database.connection() as connection_1: # noqa: SIM117 + async with database.connection() as connection_2: + assert connection_1 is connection_2 + assert connection_1.async_connection is connection_2.async_connection @pytest.mark.asyncio @@ -671,142 +659,142 @@ async def test_queries_with_expose_backend_connection(database_url): Replication of `execute()`, `execute_many()`, `fetch_all()``, and `fetch_one()` using the raw driver interface. """ - async with Database(database_url) as database: - async with database.connection() as connection: - async with connection.transaction(force_rollback=True): - # Get the driver connection - raw_connection = (await connection.get_raw_connection()).driver_connection - # Insert query - if database.url.scheme in [ - "mysql", - "mysql+asyncmy", - "mysql+aiomysql", - "postgresql+psycopg", - ]: - insert_query = r"INSERT INTO notes (text, completed) VALUES (%s, %s)" - elif database.url.scheme == "postgresql+asyncpg": - insert_query = r"INSERT INTO notes (text, completed) VALUES ($1, $2)" - else: - insert_query = r"INSERT INTO notes (text, completed) VALUES (?, ?)" - - # execute() - values = ("example1", True) - - if database.url.scheme in [ - "mysql", - "mysql+aiomysql", - "mssql", - "mssql+pyodbc", - "mssql+aioodbc", - ]: - cursor = await raw_connection.cursor() - await cursor.execute(insert_query, values) - elif database.url.scheme == "mysql+asyncmy": - async with raw_connection.cursor() as cursor: - await cursor.execute(insert_query, values) - elif database.url.scheme in [ - "postgresql", - "postgresql+asyncpg", - ]: - await raw_connection.execute(insert_query, *values) - elif database.url.scheme in [ - "postgresql+psycopg", - ]: - await raw_connection.execute(insert_query, values) - elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: - await raw_connection.execute(insert_query, values) - - # execute_many() - values = [("example2", False), ("example3", True)] - - if database.url.scheme in ["mysql", "mysql+aiomysql"]: - cursor = await raw_connection.cursor() - await cursor.executemany(insert_query, values) - elif database.url.scheme == "mysql+asyncmy": - async with raw_connection.cursor() as cursor: - await cursor.executemany(insert_query, values) - elif database.url.scheme in [ - "mssql", - "mssql+aioodbc", - "mssql+pyodbc", - ]: - cursor = await raw_connection.cursor() - for value in values: - await cursor.execute(insert_query, value) - elif database.url.scheme in ["postgresql+psycopg"]: - cursor = raw_connection.cursor() - for value in values: - await cursor.execute(insert_query, value) - else: - await raw_connection.executemany(insert_query, values) - - # Select query - select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" - - # fetch_all() - if database.url.scheme in [ - "mysql", - "mysql+aiomysql", - "mssql", - "mssql+pyodbc", - "mssql+aioodbc", - ]: - cursor = await raw_connection.cursor() - await cursor.execute(select_query) - results = await cursor.fetchall() - elif database.url.scheme == "mysql+asyncmy": - async with raw_connection.cursor() as cursor: - await cursor.execute(select_query) - results = await cursor.fetchall() - elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: - results = await raw_connection.fetch(select_query) - elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: - results = await raw_connection.execute_fetchall(select_query) - elif database.url.scheme == "postgresql+psycopg": - cursor = raw_connection.cursor() - await cursor.execute(select_query) - results = await cursor.fetchall() - - assert len(results) == 3 - # Raw output for the raw request - assert results[0][1] == "example1" - assert results[0][2] == True - assert results[1][1] == "example2" - assert results[1][2] == False - assert results[2][1] == "example3" - assert results[2][2] == True - - # fetch_one() - if database.url.scheme in [ - "postgresql", - "postgresql+asyncpg", - ]: - result = await raw_connection.fetchrow(select_query) - elif database.url.scheme in [ - "postgresql+psycopg", - ]: - cursor = raw_connection.cursor() - await cursor.execute(select_query) - result = await cursor.fetchone() - elif database.url.scheme == "mysql+asyncmy": - async with raw_connection.cursor() as cursor: - await cursor.execute(select_query) - result = await cursor.fetchone() - elif database.url.scheme in ["mssql", "mssql+pyodbc", "mssql+aioodbc"]: - cursor = await raw_connection.cursor() - try: - await cursor.execute(select_query) - result = await cursor.fetchone() - finally: - await cursor.close() - else: - cursor = await raw_connection.cursor() - await cursor.execute(select_query) - result = await cursor.fetchone() - - # Raw output for the raw request - assert result[1] == "example1" - assert result[2] == True + async with ( + Database(database_url) as database, + database.connection() as connection, + connection.transaction(force_rollback=True), + ): + # Get the driver connection + raw_connection = (await connection.get_raw_connection()).driver_connection + # Insert query + if database.url.scheme in [ + "mysql", + "mysql+asyncmy", + "mysql+aiomysql", + "postgresql+psycopg", + ]: + insert_query = r"INSERT INTO notes (text, completed) VALUES (%s, %s)" + elif database.url.scheme == "postgresql+asyncpg": + insert_query = r"INSERT INTO notes (text, completed) VALUES ($1, $2)" + else: + insert_query = r"INSERT INTO notes (text, completed) VALUES (?, ?)" + + # execute() + values = ("example1", True) + + if database.url.scheme in [ + "mysql", + "mysql+aiomysql", + "mssql", + "mssql+pyodbc", + "mssql+aioodbc", + ]: + cursor = await raw_connection.cursor() + await cursor.execute(insert_query, values) + elif database.url.scheme == "mysql+asyncmy": + async with raw_connection.cursor() as cursor: + await cursor.execute(insert_query, values) + elif database.url.scheme in [ + "postgresql", + "postgresql+asyncpg", + ]: + await raw_connection.execute(insert_query, *values) + elif database.url.scheme in [ + "postgresql+psycopg", + ] or database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: + await raw_connection.execute(insert_query, values) + + # execute_many() + values = [("example2", False), ("example3", True)] + + if database.url.scheme in ["mysql", "mysql+aiomysql"]: + cursor = await raw_connection.cursor() + await cursor.executemany(insert_query, values) + elif database.url.scheme == "mysql+asyncmy": + async with raw_connection.cursor() as cursor: + await cursor.executemany(insert_query, values) + elif database.url.scheme in [ + "mssql", + "mssql+aioodbc", + "mssql+pyodbc", + ]: + cursor = await raw_connection.cursor() + for value in values: + await cursor.execute(insert_query, value) + elif database.url.scheme in ["postgresql+psycopg"]: + cursor = raw_connection.cursor() + for value in values: + await cursor.execute(insert_query, value) + else: + await raw_connection.executemany(insert_query, values) + + # Select query + select_query = "SELECT notes.id, notes.text, notes.completed FROM notes" + + # fetch_all() + if database.url.scheme in [ + "mysql", + "mysql+aiomysql", + "mssql", + "mssql+pyodbc", + "mssql+aioodbc", + ]: + cursor = await raw_connection.cursor() + await cursor.execute(select_query) + results = await cursor.fetchall() + elif database.url.scheme == "mysql+asyncmy": + async with raw_connection.cursor() as cursor: + await cursor.execute(select_query) + results = await cursor.fetchall() + elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]: + results = await raw_connection.fetch(select_query) + elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]: + results = await raw_connection.execute_fetchall(select_query) + elif database.url.scheme == "postgresql+psycopg": + cursor = raw_connection.cursor() + await cursor.execute(select_query) + results = await cursor.fetchall() + + assert len(results) == 3 + # Raw output for the raw request + assert results[0][1] == "example1" + assert results[0][2] == True + assert results[1][1] == "example2" + assert results[1][2] == False + assert results[2][1] == "example3" + assert results[2][2] == True + + # fetch_one() + if database.url.scheme in [ + "postgresql", + "postgresql+asyncpg", + ]: + result = await raw_connection.fetchrow(select_query) + elif database.url.scheme in [ + "postgresql+psycopg", + ]: + cursor = raw_connection.cursor() + await cursor.execute(select_query) + result = await cursor.fetchone() + elif database.url.scheme == "mysql+asyncmy": + async with raw_connection.cursor() as cursor: + await cursor.execute(select_query) + result = await cursor.fetchone() + elif database.url.scheme in ["mssql", "mssql+pyodbc", "mssql+aioodbc"]: + cursor = await raw_connection.cursor() + try: + await cursor.execute(select_query) + result = await cursor.fetchone() + finally: + await cursor.close() + else: + cursor = await raw_connection.cursor() + await cursor.execute(select_query) + result = await cursor.fetchone() + + # Raw output for the raw request + assert result[1] == "example1" + assert result[2] == True @pytest.mark.asyncio @@ -814,10 +802,11 @@ async def test_database_url_interface(database_mixed_url): """ Test that Database instances expose a `.url` attribute. """ - if isinstance(database_mixed_url, str): - data = {"url": database_mixed_url} - else: - data = {"config": database_mixed_url} + data = ( + {"url": database_mixed_url} + if isinstance(database_mixed_url, str) + else {"config": database_mixed_url} + ) async with Database(**data) as database: assert isinstance(database.url, DatabaseURL) @@ -843,19 +832,18 @@ async def test_column_names(database_url, select_query): """ Test that column names are exposed correctly through `._mapping.keys()` on each row. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - # insert values - query = notes.insert() - values = {"text": "example1", "completed": True} - await database.execute(query, values) - # fetch results - results = await database.fetch_all(query=select_query) - assert len(results) == 1 + async with Database(database_url) as database, database.transaction(force_rollback=True): + # insert values + query = notes.insert() + values = {"text": "example1", "completed": True} + await database.execute(query, values) + # fetch results + results = await database.fetch_all(query=select_query) + assert len(results) == 1 - assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"] - assert results[0].text == "example1" - assert results[0].completed == True + assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"] + assert results[0].text == "example1" + assert results[0].completed == True @pytest.mark.asyncio diff --git a/tests/test_dbapi2.py b/tests/test_dbapi2.py index 6b82441..7e99f76 100644 --- a/tests/test_dbapi2.py +++ b/tests/test_dbapi2.py @@ -20,9 +20,11 @@ async def test_dbapi2_connect(): """ Test that a basic connection works. """ - async with Database("dbapi2://testsuite.sqlite3", dbapi_path="sqlite3") as database: - async with database.connection(): - pass + async with ( + Database("dbapi2://testsuite.sqlite3", dbapi_path="sqlite3") as database, + database.connection(), + ): + pass @pytest.mark.asyncio diff --git a/tests/test_really_old_jdbc.py b/tests/test_really_old_jdbc.py index 0ca89ac..d182617 100644 --- a/tests/test_really_old_jdbc.py +++ b/tests/test_really_old_jdbc.py @@ -1,3 +1,5 @@ +import contextlib + import pytest import sqlalchemy from sqlalchemy.pool import StaticPool @@ -24,12 +26,14 @@ async def test_jdbc_connect(): """ Test basic connection """ - async with Database( - "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", - poolclass=StaticPool, - ) as database: - async with database.connection(): - pass + async with ( + Database( + "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", + poolclass=StaticPool, + ) as database, + database.connection(), + ): + pass @pytest.mark.asyncio @@ -38,95 +42,93 @@ async def test_jdbc_queries(): Test that the basic `execute()`, `execute_many()`, `fetch_all()``, `fetch_one()`, `iterate()` and `batched_iterate()` interfaces are all supported (using SQLAlchemy core). """ - async with Database( - "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", - poolclass=StaticPool, - ) as database: - async with database.connection() as connection: - await connection.create_all(metadata) - try: - async with connection.transaction(force_rollback=True): - # execute() - query = notes.insert() - values = {"text": "example1", "completed": True} - try: - await connection.execute(query, values) - except Exception: - pass - - # execute_many() - query = notes.insert() - values = [ - {"text": "example2", "completed": False}, - {"text": "example3", "completed": True}, - ] - await connection.execute_many(query, values) - - # fetch_all() - query = notes.select() - results = await database.fetch_all(query=query) - - assert len(results) == 3 - assert results[0].text == "example1" - assert results[0].completed is True - assert results[1].text == "example2" - assert results[1].completed is False - assert results[2].text == "example3" - assert results[2].completed is True - - # fetch_one() - query = notes.select() - result = await database.fetch_one(query=query) - assert result.text == "example1" - assert result.completed is True - - # fetch_val() - query = sqlalchemy.sql.select(*[notes.c.text]) - result = await database.fetch_val(query=query) - assert result == "example1" - - # fetch_val() with no rows - query = sqlalchemy.sql.select(*[notes.c.text]).where( - notes.c.text == "impossible" - ) - result = await database.fetch_val(query=query) - assert result is None - - # fetch_val() with a different column - query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) - result = await database.fetch_val(query=query, column=1) - assert result == "example1" - - # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) - query = sqlalchemy.sql.select(*[notes.c.text]) - result = await database.fetch_one(query=query) - assert result.text == "example1" - assert result[0] == "example1" - - # iterate() - query = notes.select() - iterate_results = [] - async for result in database.iterate(query=query): - iterate_results.append(result) - assert len(iterate_results) == 3 - assert iterate_results[0].text == "example1" - assert iterate_results[0].completed is True - assert iterate_results[1].text == "example2" - assert iterate_results[1].completed is False - assert iterate_results[2].text == "example3" - assert iterate_results[2].completed is True - - # batched_iterate() - query = notes.select() - batched_iterate_results = [] - async for result in database.batched_iterate(query=query, batch_size=2): - batched_iterate_results.append(result) - assert len(batched_iterate_results) == 2 - assert batched_iterate_results[0][0].text == "example1" - assert batched_iterate_results[0][0].completed is True - assert batched_iterate_results[0][1].text == "example2" - assert batched_iterate_results[0][1].completed is False - assert batched_iterate_results[1][0].text == "example3" - assert batched_iterate_results[1][0].completed is True - finally: - await connection.drop_all(metadata) + async with ( + Database( + "jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC", + poolclass=StaticPool, + ) as database, + database.connection() as connection, + ): + await connection.create_all(metadata) + try: + async with connection.transaction(force_rollback=True): + # execute() + query = notes.insert() + values = {"text": "example1", "completed": True} + with contextlib.suppress(Exception): + await connection.execute(query, values) + + # execute_many() + query = notes.insert() + values = [ + {"text": "example2", "completed": False}, + {"text": "example3", "completed": True}, + ] + await connection.execute_many(query, values) + + # fetch_all() + query = notes.select() + results = await database.fetch_all(query=query) + + assert len(results) == 3 + assert results[0].text == "example1" + assert results[0].completed is True + assert results[1].text == "example2" + assert results[1].completed is False + assert results[2].text == "example3" + assert results[2].completed is True + + # fetch_one() + query = notes.select() + result = await database.fetch_one(query=query) + assert result.text == "example1" + assert result.completed is True + + # fetch_val() + query = sqlalchemy.sql.select(*[notes.c.text]) + result = await database.fetch_val(query=query) + assert result == "example1" + + # fetch_val() with no rows + query = sqlalchemy.sql.select(*[notes.c.text]).where(notes.c.text == "impossible") + result = await database.fetch_val(query=query) + assert result is None + + # fetch_val() with a different column + query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) + result = await database.fetch_val(query=query, column=1) + assert result == "example1" + + # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) + query = sqlalchemy.sql.select(*[notes.c.text]) + result = await database.fetch_one(query=query) + assert result.text == "example1" + assert result[0] == "example1" + + # iterate() + query = notes.select() + iterate_results = [] + async for result in database.iterate(query=query): + iterate_results.append(result) + assert len(iterate_results) == 3 + assert iterate_results[0].text == "example1" + assert iterate_results[0].completed is True + assert iterate_results[1].text == "example2" + assert iterate_results[1].completed is False + assert iterate_results[2].text == "example3" + assert iterate_results[2].completed is True + + # batched_iterate() + query = notes.select() + batched_iterate_results = [] + async for result in database.batched_iterate(query=query, batch_size=2): + batched_iterate_results.append(result) + assert len(batched_iterate_results) == 2 + assert batched_iterate_results[0][0].text == "example1" + assert batched_iterate_results[0][0].completed is True + assert batched_iterate_results[0][1].text == "example2" + assert batched_iterate_results[0][1].completed is False + assert batched_iterate_results[1][0].text == "example3" + assert batched_iterate_results[1][0].completed is True + finally: + await connection.drop_all(metadata) diff --git a/tests/test_transactions.py b/tests/test_transactions.py index d7667d6..f632b8d 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -1,7 +1,7 @@ import asyncio import gc import os -from typing import MutableMapping +from collections.abc import MutableMapping import pyodbc import pytest @@ -14,7 +14,7 @@ DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")] -if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())): +if not any(x.endswith(" for SQL Server") for x in pyodbc.drivers()): DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS)) @@ -35,10 +35,7 @@ async def test_commit_on_root_transaction(database_url): Deal with this here, and delete the records rather than rolling back. """ - if isinstance(database_url, str): - data = {"url": database_url} - else: - data = {"config": database_url} + data = {"url": database_url} if isinstance(database_url, str) else {"config": database_url} async with Database(**data) as database: try: @@ -205,10 +202,7 @@ async def test_rollback_isolation(database_url): """ Ensure that `database.transaction(force_rollback=True)` provides strict isolation. """ - if isinstance(database_url, str): - data = {"url": database_url} - else: - data = {"config": database_url} + data = {"url": database_url} if isinstance(database_url, str) else {"config": database_url} async with Database(**data) as database: # Perform some INSERT operations on the database. @@ -227,10 +221,7 @@ async def test_rollback_isolation_with_contextmanager(database_url): """ Ensure that `database.force_rollback()` provides strict isolation. """ - if isinstance(database_url, str): - data = {"url": database_url} - else: - data = {"config": database_url} + data = {"url": database_url} if isinstance(database_url, str) else {"config": database_url} database = Database(**data) @@ -253,15 +244,14 @@ async def test_transaction_commit(database_url): Ensure that transaction commit is supported. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - async with database.transaction(): - query = notes.insert().values(text="example1", completed=True) - await database.execute(query) + async with Database(database_url) as database, database.transaction(force_rollback=True): + async with database.transaction(): + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 1 + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 @pytest.mark.asyncio @@ -307,19 +297,21 @@ def delete_independently(): conn.execute(query) conn.close() - async with Database(database_url) as database: - async with database.transaction(force_rollback=True, isolation_level="serializable"): - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 0 + async with ( + Database(database_url) as database, + database.transaction(force_rollback=True, isolation_level="serializable"), + ): + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 - insert_independently() + insert_independently() - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 0 + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 - delete_independently() + delete_independently() @pytest.mark.asyncio @@ -327,24 +319,20 @@ async def test_transaction_rollback(database_url): """ Ensure that transaction rollback is supported. """ - if isinstance(database_url, str): - data = {"url": database_url} - else: - data = {"config": database_url} + data = {"url": database_url} if isinstance(database_url, str) else {"config": database_url} - async with Database(**data) as database: - async with database.transaction(force_rollback=True): - try: - async with database.transaction(): - query = notes.insert().values(text="example1", completed=True) - await database.execute(query) - raise RuntimeError() - except RuntimeError: - pass + async with Database(**data) as database, database.transaction(force_rollback=True): + try: + async with database.transaction(): + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) + raise RuntimeError() + except RuntimeError: + pass - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 0 + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 @pytest.mark.asyncio @@ -352,25 +340,21 @@ async def test_transaction_commit_low_level(database_url): """ Ensure that an explicit `await transaction.commit()` is supported. """ - if isinstance(database_url, str): - data = {"url": database_url} - else: - data = {"config": database_url} + data = {"url": database_url} if isinstance(database_url, str) else {"config": database_url} - async with Database(**data) as database: - async with database.transaction(force_rollback=True): - transaction = await database.transaction() - try: - query = notes.insert().values(text="example1", completed=True) - await database.execute(query) - except Exception: - await transaction.rollback() - else: - await transaction.commit() + async with Database(**data) as database, database.transaction(force_rollback=True): + transaction = await database.transaction() + try: + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) + except Exception: + await transaction.rollback() + else: + await transaction.commit() - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 1 + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 1 @pytest.mark.asyncio @@ -379,21 +363,20 @@ async def test_transaction_rollback_low_level(database_url): Ensure that an explicit `await transaction.rollback()` is supported. """ - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - transaction = await database.transaction() - try: - query = notes.insert().values(text="example1", completed=True) - await database.execute(query) - raise RuntimeError() - except Exception: - await transaction.rollback() - else: # pragma: no cover - await transaction.commit() + async with Database(database_url) as database, database.transaction(force_rollback=True): + transaction = await database.transaction() + try: + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) + raise RuntimeError() + except Exception: + await transaction.rollback() + else: # pragma: no cover + await transaction.commit() - query = notes.select() - results = await database.fetch_all(query=query) - assert len(results) == 0 + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 @pytest.mark.parametrize( @@ -475,27 +458,26 @@ async def test_transaction_context_child_task_inheritance_example(database_url): if db.url.dialect == "mssql": return - async with Database(database_url) as database: - async with database.transaction(): - # Create a note - await database.execute(notes.insert().values(id=1, text="setup", completed=True)) + async with Database(database_url) as database, database.transaction(): + # Create a note + await database.execute(notes.insert().values(id=1, text="setup", completed=True)) - # Change the note from the same task - await database.execute(notes.update().where(notes.c.id == 1).values(text="prior")) + # Change the note from the same task + await database.execute(notes.update().where(notes.c.id == 1).values(text="prior")) - # Confirm the change - result = await database.fetch_one(notes.select().where(notes.c.id == 1)) - assert result.text == "prior" + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" - async def run_update_from_child_task(connection): - # Change the note from a child task - await connection.execute(notes.update().where(notes.c.id == 1).values(text="test")) + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute(notes.update().where(notes.c.id == 1).values(text="test")) - await asyncio.create_task(run_update_from_child_task(database.connection())) + await asyncio.create_task(run_update_from_child_task(database.connection())) - # Confirm the child's change - result = await database.fetch_one(notes.select().where(notes.c.id == 1)) - assert result.text == "test" + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" @pytest.mark.asyncio