Skip to content

Commit

Permalink
fixes and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Aug 23, 2024
1 parent de55806 commit 58c75f5
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 27 deletions.
43 changes: 31 additions & 12 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ def _init_thread(database: Database, is_initialized: Event) -> None:
loop.run_forever()
except RuntimeError:
pass
task.result()
try:
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
task.result()
finally:
del task
loop.close()
database._loop = None
try:
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
del task
loop.close()
database._loop = None


class Connection:
Expand All @@ -54,7 +56,9 @@ def __init__(
self._full_isolation = full_isolation
self._database = database
if full_isolation:
self._database = database.__class__(database, force_rollback=force_rollback)
self._database = database.__class__(
database, force_rollback=force_rollback, full_isolation=False
)
self._database._call_hooks = False
self._database._global_connection = self
self._backend = backend
Expand All @@ -77,11 +81,23 @@ async def __aenter__(self) -> Connection:
if self._full_isolation:
ctx = copy_context()
is_initialized = Event()
self._isolation_thread = Thread(
target=ctx.run, args=[_init_thread, self._database, is_initialized], daemon=False
self._isolation_thread = thread = Thread(
target=ctx.run,
args=[
_init_thread,
self._database.__class__(
self._database, force_rollback=True, full_isolation=False
),
is_initialized,
],
daemon=False,
)
self._isolation_thread.start()
thread.start()
is_initialized.wait()
if not thread.is_alive():
thread = self._isolation_thread
self._isolation_thread = None
thread.join()

async with self._connection_lock:
self._connection_counter += 1
Expand Down Expand Up @@ -134,9 +150,12 @@ async def __aexit__(
finally:
if closing and self._full_isolation:
assert self._isolation_thread is not None
self._database._loop.stop()
self._isolation_thread.join()
thread = self._isolation_thread
loop = self._database._loop
self._isolation_thread = None
if loop:
loop.stop()
thread.join()

@property
def _loop(self) -> typing.Any:
Expand Down
14 changes: 10 additions & 4 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,9 @@ async def connect(self) -> bool:
self.is_connected = True

if self._global_connection is None:
connection = Connection(self, self.backend, force_rollback=True, full_isolation=self._full_isolation)
connection = Connection(
self, self.backend, force_rollback=True, full_isolation=self._full_isolation
)
self._global_connection = connection
return True

Expand All @@ -321,9 +323,13 @@ async def disconnect(
loop = asyncio.get_running_loop()
del parent_database._databases_map[loop]
if force:
for sub_database in self._databases_map.values():
await sub_database.disconnect(True)
self._databases_map = {}
if self._databases_map:
assert not self._databases_map, "sub databases still active, force terminate them"
for sub_database in self._databases_map.values():
asyncio.run_coroutine_threadsafe(
sub_database.disconnect(True), sub_database._loop
)
self._databases_map = {}
assert not self._databases_map, "sub databases still active"

try:
Expand Down
22 changes: 15 additions & 7 deletions databasez/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class DatabaseTestClient(Database):
# is used for copying Database and DatabaseTestClientand providing an early url
test_db_url: str
# hooks for overwriting defaults of args with None
testclient_default_full_isolation: bool = False
testclient_default_force_rollback: bool = False
testclient_default_lazy_setup: bool = False
# customization hooks
Expand All @@ -47,6 +48,7 @@ def __init__(
*,
force_rollback: typing.Union[bool, None] = None,
use_existing: typing.Union[bool, None] = None,
full_isolation: typing.Union[bool, None] = None,
drop_database: typing.Union[bool, None] = None,
lazy_setup: typing.Union[bool, None] = None,
test_prefix: typing.Union[str, None] = None,
Expand All @@ -56,6 +58,8 @@ def __init__(
use_existing = self.testclient_default_use_existing
if drop_database is None:
drop_database = self.testclient_default_drop_database
if full_isolation is None:
full_isolation = self.testclient_default_full_isolation
if test_prefix is None:
test_prefix = self.testclient_default_test_prefix
self._setup_executed_init = False
Expand Down Expand Up @@ -144,7 +148,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, "postgres", "template1", "template0", None):
url = url.replace(database=db)
async with Database(url) as db_client:
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
Expand All @@ -157,7 +161,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME = '%s'" % database
)
async with Database(url) as db_client:
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))

elif dialect_name == "sqlite":
Expand All @@ -169,7 +173,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
return True
else:
text = "SELECT 1"
async with Database(url) as db_client:
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
Expand Down Expand Up @@ -201,9 +205,11 @@ async def create_database(
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
db_client = Database(
url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
)
else:
db_client = Database(url, force_rollback=False)
db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "postgresql":
if not template:
Expand Down Expand Up @@ -255,9 +261,11 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
db_client = Database(
url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
)
else:
db_client = Database(url, force_rollback=False)
db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "sqlite" and database and database != ":memory:":
try:
Expand Down
2 changes: 2 additions & 0 deletions databasez/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,15 @@ def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typin
inp = asyncio.wait_for(inp, timeout=timeout)
return inp


async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any:
if timeout is not None and timeout > 0 and inspect.isawaitable(inp):
inp = await asyncio.wait_for(inp, timeout=timeout)
elif inspect.isawaitable(inp):
return await inp
return inp


class AsyncHelperDatabase:
def __init__(
self,
Expand Down
7 changes: 7 additions & 0 deletions docs/connections-and-transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ from databasez import Database

<sup>Default: `None`</sup>

* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
in an extra thread. This way it is possible to use blocking operations.

<sup>Default: `None`</sup>

* **config** - A python like dictionary as alternative to the `url` that contains the information
to connect to the database.

Expand All @@ -35,6 +40,8 @@ to connect to the database.
Be careful when setting up the `url` or `config`. You can use one or the other but not both
at the same time.

!!! Warning
`full_isolation` is not mature and shouldn't be used in production code.

**Attributes***

Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# Release Notes


## 0.10.0


### Added

- `full_isolation` parameter. Isolate the force_rollback Connection in a thread.


## 0.9.7

### Added
Expand Down
15 changes: 11 additions & 4 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def test_concurrent_access_on_single_connection(database_url):
database_url = DatabaseURL(str(database_url.url))
if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
pytest.skip("Test requires sleep function")
async with Database(database_url, force_rollback=True) as database:
async with Database(database_url, force_rollback=True, full_isolation=False) as database:

async def db_lookup():
if database_url.dialect.startswith("postgres"):
Expand All @@ -75,7 +75,9 @@ async def db_lookup():
@pytest.mark.asyncio
async def test_multi_thread(database_url, force_rollback):
database_url = DatabaseURL(str(database_url.url))
async with Database(database_url, force_rollback=force_rollback) as database:
async with Database(
database_url, force_rollback=force_rollback, full_isolation=False
) as database:

async def db_lookup(in_thread):
async with database.connection() as conn:
Expand Down Expand Up @@ -111,7 +113,12 @@ def _future_helper(awaitable, future):
@pytest.mark.parametrize("force_rollback", [True, False])
@pytest.mark.asyncio
async def test_multi_thread_db_contextmanager(database_url, force_rollback, join_type):
async with Database(database_url, force_rollback=force_rollback) as database:
if join_type.startswith("thread_join") and force_rollback:
pytest.skip("not supported yet")

async with Database(
database_url, force_rollback=force_rollback, full_isolation=False
) as database:
query = notes.insert().values(text="examplecontext", completed=True)
await database.execute(query, timeout=10)
database._non_copied_attribute = True
Expand Down Expand Up @@ -146,7 +153,7 @@ async def db_connect(depth=3):
args.insert(0, ctx.run)
thread = Thread(target=args[0], args=args[1:])
thread.start()
future.result()
future.result(4)
else:
await to_thread(asyncio.run, asyncio.wait_for(db_connect(), 5))
assert database.ref_counter == 0
Expand Down

0 comments on commit 58c75f5

Please sign in to comment.