diff --git a/databasez/core/connection.py b/databasez/core/connection.py
index e917cb3..d3def51 100644
--- a/databasez/core/connection.py
+++ b/databasez/core/connection.py
@@ -33,14 +33,16 @@ def _init_thread(database: Database, is_initialized: Event) -> None:
loop.run_forever()
except RuntimeError:
pass
- task.result()
try:
- loop.run_until_complete(database.disconnect())
- loop.run_until_complete(loop.shutdown_asyncgens())
+ task.result()
finally:
- del task
- loop.close()
- database._loop = None
+ try:
+ loop.run_until_complete(database.disconnect())
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ del task
+ loop.close()
+ database._loop = None
class Connection:
@@ -54,7 +56,9 @@ def __init__(
self._full_isolation = full_isolation
self._database = database
if full_isolation:
- self._database = database.__class__(database, force_rollback=force_rollback)
+ self._database = database.__class__(
+ database, force_rollback=force_rollback, full_isolation=False
+ )
self._database._call_hooks = False
self._database._global_connection = self
self._backend = backend
@@ -77,11 +81,23 @@ async def __aenter__(self) -> Connection:
if self._full_isolation:
ctx = copy_context()
is_initialized = Event()
- self._isolation_thread = Thread(
- target=ctx.run, args=[_init_thread, self._database, is_initialized], daemon=False
+ self._isolation_thread = thread = Thread(
+ target=ctx.run,
+ args=[
+ _init_thread,
+ self._database.__class__(
+ self._database, force_rollback=True, full_isolation=False
+ ),
+ is_initialized,
+ ],
+ daemon=False,
)
- self._isolation_thread.start()
+ thread.start()
is_initialized.wait()
+ if not thread.is_alive():
+ thread = self._isolation_thread
+ self._isolation_thread = None
+ thread.join()
async with self._connection_lock:
self._connection_counter += 1
@@ -134,9 +150,12 @@ async def __aexit__(
finally:
if closing and self._full_isolation:
assert self._isolation_thread is not None
- self._database._loop.stop()
- self._isolation_thread.join()
+ thread = self._isolation_thread
+ loop = self._database._loop
self._isolation_thread = None
+ if loop:
+ loop.stop()
+ thread.join()
@property
def _loop(self) -> typing.Any:
diff --git a/databasez/core/database.py b/databasez/core/database.py
index b0acb6f..3159d79 100644
--- a/databasez/core/database.py
+++ b/databasez/core/database.py
@@ -294,7 +294,9 @@ async def connect(self) -> bool:
self.is_connected = True
if self._global_connection is None:
- connection = Connection(self, self.backend, force_rollback=True, full_isolation=self._full_isolation)
+ connection = Connection(
+ self, self.backend, force_rollback=True, full_isolation=self._full_isolation
+ )
self._global_connection = connection
return True
@@ -321,9 +323,13 @@ async def disconnect(
loop = asyncio.get_running_loop()
del parent_database._databases_map[loop]
if force:
- for sub_database in self._databases_map.values():
- await sub_database.disconnect(True)
- self._databases_map = {}
+ if self._databases_map:
+ assert not self._databases_map, "sub databases still active, force terminate them"
+ for sub_database in self._databases_map.values():
+ asyncio.run_coroutine_threadsafe(
+ sub_database.disconnect(True), sub_database._loop
+ )
+ self._databases_map = {}
assert not self._databases_map, "sub databases still active"
try:
diff --git a/databasez/testclient.py b/databasez/testclient.py
index c00b262..ae6be71 100644
--- a/databasez/testclient.py
+++ b/databasez/testclient.py
@@ -34,6 +34,7 @@ class DatabaseTestClient(Database):
# is used for copying Database and DatabaseTestClientand providing an early url
test_db_url: str
# hooks for overwriting defaults of args with None
+ testclient_default_full_isolation: bool = False
testclient_default_force_rollback: bool = False
testclient_default_lazy_setup: bool = False
# customization hooks
@@ -47,6 +48,7 @@ def __init__(
*,
force_rollback: typing.Union[bool, None] = None,
use_existing: typing.Union[bool, None] = None,
+ full_isolation: typing.Union[bool, None] = None,
drop_database: typing.Union[bool, None] = None,
lazy_setup: typing.Union[bool, None] = None,
test_prefix: typing.Union[str, None] = None,
@@ -56,6 +58,8 @@ def __init__(
use_existing = self.testclient_default_use_existing
if drop_database is None:
drop_database = self.testclient_default_drop_database
+ if full_isolation is None:
+ full_isolation = self.testclient_default_full_isolation
if test_prefix is None:
test_prefix = self.testclient_default_test_prefix
self._setup_executed_init = False
@@ -144,7 +148,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, "postgres", "template1", "template0", None):
url = url.replace(database=db)
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
@@ -157,7 +161,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME = '%s'" % database
)
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
elif dialect_name == "sqlite":
@@ -169,7 +173,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
return True
else:
text = "SELECT 1"
- async with Database(url) as db_client:
+ async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
except (ProgrammingError, OperationalError):
@@ -201,9 +205,11 @@ async def create_database(
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
- db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
+ db_client = Database(
+ url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
+ )
else:
- db_client = Database(url, force_rollback=False)
+ db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "postgresql":
if not template:
@@ -255,9 +261,11 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N
dialect_name == "postgresql"
and dialect_driver in {"asyncpg", "pg8000", "psycopg", "psycopg2", "psycopg2cffi"}
):
- db_client = Database(url, isolation_level="AUTOCOMMIT", force_rollback=False)
+ db_client = Database(
+ url, isolation_level="AUTOCOMMIT", force_rollback=False, full_isolation=False
+ )
else:
- db_client = Database(url, force_rollback=False)
+ db_client = Database(url, force_rollback=False, full_isolation=False)
async with db_client:
if dialect_name == "sqlite" and database and database != ":memory:":
try:
diff --git a/databasez/utils.py b/databasez/utils.py
index 743df0f..4b1b289 100644
--- a/databasez/utils.py
+++ b/databasez/utils.py
@@ -145,6 +145,7 @@ def _run_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typin
inp = asyncio.wait_for(inp, timeout=timeout)
return inp
+
async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -> typing.Any:
if timeout is not None and timeout > 0 and inspect.isawaitable(inp):
inp = await asyncio.wait_for(inp, timeout=timeout)
@@ -152,6 +153,7 @@ async def _arun_with_timeout(inp: typing.Any, timeout: typing.Optional[float]) -
return await inp
return inp
+
class AsyncHelperDatabase:
def __init__(
self,
diff --git a/docs/connections-and-transactions.md b/docs/connections-and-transactions.md
index ee8300d..e6a1047 100644
--- a/docs/connections-and-transactions.md
+++ b/docs/connections-and-transactions.md
@@ -24,6 +24,11 @@ from databasez import Database
Default: `None`
+* **full_isolation** - Special mode for using force_rollback with nested queries. This parameter fully isolates the global connection
+ in an extra thread. This way it is possible to use blocking operations.
+
+ Default: `None`
+
* **config** - A python like dictionary as alternative to the `url` that contains the information
to connect to the database.
@@ -35,6 +40,8 @@ to connect to the database.
Be careful when setting up the `url` or `config`. You can use one or the other but not both
at the same time.
+!!! Warning
+ `full_isolation` is not mature and shouldn't be used in production code.
**Attributes***
diff --git a/docs/release-notes.md b/docs/release-notes.md
index c099b24..9ec9f7f 100644
--- a/docs/release-notes.md
+++ b/docs/release-notes.md
@@ -1,6 +1,14 @@
# Release Notes
+## 0.10.0
+
+
+### Added
+
+- `full_isolation` parameter. Isolate the force_rollback Connection in a thread.
+
+
## 0.9.7
### Added
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
index 81aec73..87f1530 100644
--- a/tests/test_concurrency.py
+++ b/tests/test_concurrency.py
@@ -56,7 +56,7 @@ async def test_concurrent_access_on_single_connection(database_url):
database_url = DatabaseURL(str(database_url.url))
if not _startswith(database_url.dialect, ["mysql", "mariadb", "postgres", "mssql"]):
pytest.skip("Test requires sleep function")
- async with Database(database_url, force_rollback=True) as database:
+ async with Database(database_url, force_rollback=True, full_isolation=False) as database:
async def db_lookup():
if database_url.dialect.startswith("postgres"):
@@ -75,7 +75,9 @@ async def db_lookup():
@pytest.mark.asyncio
async def test_multi_thread(database_url, force_rollback):
database_url = DatabaseURL(str(database_url.url))
- async with Database(database_url, force_rollback=force_rollback) as database:
+ async with Database(
+ database_url, force_rollback=force_rollback, full_isolation=False
+ ) as database:
async def db_lookup(in_thread):
async with database.connection() as conn:
@@ -111,7 +113,12 @@ def _future_helper(awaitable, future):
@pytest.mark.parametrize("force_rollback", [True, False])
@pytest.mark.asyncio
async def test_multi_thread_db_contextmanager(database_url, force_rollback, join_type):
- async with Database(database_url, force_rollback=force_rollback) as database:
+ if join_type.startswith("thread_join") and force_rollback:
+ pytest.skip("not supported yet")
+
+ async with Database(
+ database_url, force_rollback=force_rollback, full_isolation=False
+ ) as database:
query = notes.insert().values(text="examplecontext", completed=True)
await database.execute(query, timeout=10)
database._non_copied_attribute = True
@@ -146,7 +153,7 @@ async def db_connect(depth=3):
args.insert(0, ctx.run)
thread = Thread(target=args[0], args=args[1:])
thread.start()
- future.result()
+ future.result(4)
else:
await to_thread(asyncio.run, asyncio.wait_for(db_connect(), 5))
assert database.ref_counter == 0