diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..7e1ee59f9be7 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -22,6 +22,7 @@ use std::{ use tiberius::*; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +use std::sync::Arc; /// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. #[cfg(feature = "expose-drivers")] @@ -106,10 +107,25 @@ impl TransactionCapable for Mssql { .or(self.url.query_params.transaction_isolation_level) .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + let mut transaction_depth = self.transaction_depth.lock().await; + *transaction_depth += 1; + let st_depth = *transaction_depth + 0; + + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + let opts = TransactionOptions::new( + isolation, + self.requires_isolation_first(), + self.transaction_depth.clone(), + commit_stmt, + rollback_stmt, + ); + Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, + DefaultTransaction::new(self, &begin_statement, opts).await?, )) } } @@ -273,6 +289,7 @@ pub struct Mssql { url: MssqlUrl, socket_timeout: Option, is_healthy: AtomicBool, + transaction_depth: Arc>, } impl Mssql { @@ -304,6 +321,7 @@ impl Mssql { url, socket_timeout, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }; if let Some(isolation) = this.url.transaction_isolation_level() { @@ -443,8 +461,41 @@ impl Queryable for Mssql { Ok(()) } - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVE TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN TRAN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + // MSSQL doesn't have a "RELEASE SAVEPOINT" equivalent, so in a nested + // transaction we just continue onwards + let ret = if depth > 1 { + " ".to_string() + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TRANSACTION savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret } fn requires_isolation_first(&self) -> bool { diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index e5a1b794ab5b..68d9cdb95e65 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -23,6 +23,7 @@ use std::{ }; use tokio::sync::Mutex; use url::{Host, Url}; +use std::sync::Arc; /// The underlying MySQL driver. Only available with the `expose-drivers` /// Cargo feature. @@ -39,6 +40,7 @@ pub struct Mysql { socket_timeout: Option, is_healthy: AtomicBool, statement_cache: Mutex>, + transaction_depth: Arc>, } /// Wraps a connection url and exposes the parsing logic used by quaint, including default values. @@ -374,6 +376,7 @@ impl Mysql { statement_cache: Mutex::new(url.cache()), url, is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -581,6 +584,42 @@ impl Queryable for Mysql { fn requires_isolation_first(&self) -> bool { true } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } #[cfg(test)] diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 2c81144c812b..e1ab15bb9a8e 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -26,6 +26,7 @@ use tokio_postgres::{ Client, Config, Statement, }; use url::{Host, Url}; +use std::sync::Arc; pub(crate) const DEFAULT_SCHEMA: &str = "public"; @@ -61,6 +62,7 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex>, is_healthy: AtomicBool, + transaction_depth: Arc>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -650,6 +652,7 @@ impl PostgreSql { pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -930,6 +933,42 @@ impl Queryable for PostgreSql { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } /// Sorted list of CockroachDB's reserved keywords. diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index 09dbc7abba4c..f991e5506d9e 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -87,8 +87,18 @@ pub trait Queryable: Send + Sync { } /// Statement to begin a transaction - fn begin_statement(&self) -> &'static str { - "BEGIN" + async fn begin_statement(&self, _depth: i32) -> String { + "BEGIN".to_string() + } + + /// Statement to commit a transaction + async fn commit_statement(&self, _depth: i32) -> String { + "COMMIT".to_string() + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, _depth: i32) -> String { + "ROLLBACK".to_string() } /// Sets the transaction isolation level to given value. @@ -117,10 +127,28 @@ macro_rules! impl_default_TransactionCapable { &'a self, isolation: Option, ) -> crate::Result> { - let opts = crate::connector::TransactionOptions::new(isolation, self.requires_isolation_first()); + let depth = self.transaction_depth.clone(); + let mut depth_guard = self.transaction_depth.lock().await; + *depth_guard += 1; + + let st_depth = *depth_guard; + + let begin_statement = self.begin_statement(st_depth).await; + let commit_stmt = self.commit_statement(st_depth).await; + let rollback_stmt = self.rollback_statement(st_depth).await; + + + + let opts = crate::connector::TransactionOptions::new( + isolation, + self.requires_isolation_first(), + depth, + commit_stmt, + rollback_stmt, + ); Ok(Box::new( - crate::connector::DefaultTransaction::new(self, self.begin_statement(), opts).await?, + crate::connector::DefaultTransaction::new(self, &begin_statement, opts).await?, )) } } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 6db49523c80a..83a3c3d0274a 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -13,6 +13,7 @@ use crate::{ use async_trait::async_trait; use std::{convert::TryFrom, path::Path, time::Duration}; use tokio::sync::Mutex; +use std::sync::Arc; pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; @@ -23,6 +24,7 @@ pub use rusqlite; /// A connector interface for the SQLite database pub struct Sqlite { pub(crate) client: Mutex, + transaction_depth: Arc>, } /// Wraps a connection url and exposes the parsing logic used by Quaint, @@ -139,7 +141,10 @@ impl TryFrom<&str> for Sqlite { let client = Mutex::new(conn); - Ok(Sqlite { client }) + Ok(Sqlite { + client, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), + }) } } @@ -154,6 +159,7 @@ impl Sqlite { Ok(Sqlite { client: Mutex::new(client), + transaction_depth: Arc::new(futures::lock::Mutex::new(0)), }) } @@ -250,6 +256,42 @@ impl Queryable for Sqlite { fn requires_isolation_first(&self) -> bool { false } + + /// Statement to begin a transaction + async fn begin_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "BEGIN".to_string() + }; + + return ret + } + + /// Statement to commit a transaction + async fn commit_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("RELEASE SAVEPOINT savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "COMMIT".to_string() + }; + + return ret + } + + /// Statement to rollback a transaction + async fn rollback_statement(&self, depth: i32) -> String { + let savepoint_stmt = format!("ROLLBACK TO savepoint{}", depth); + let ret = if depth > 1 { + savepoint_stmt + } else { + "ROLLBACK".to_string() + }; + + return ret + } } #[cfg(test)] diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a8..e6caa01b0e2a 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -6,16 +6,18 @@ use crate::{ use async_trait::async_trait; use metrics::{decrement_gauge, increment_gauge}; use std::{fmt, str::FromStr}; +use futures::lock::Mutex; +use std::sync::Arc; extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()>; + async fn commit(&mut self) -> crate::Result<()>; /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()>; + async fn rollback(&mut self) -> crate::Result<()>; /// workaround for lack of upcasting between traits https://github.com/rust-lang/rust/issues/65991 fn as_queryable(&self) -> &dyn Queryable; @@ -27,6 +29,15 @@ pub(crate) struct TransactionOptions { /// Whether or not to put the isolation level `SET` before or after the `BEGIN`. pub(crate) isolation_first: bool, + + /// The depth of the transaction, used to determine the nested transaction statements. + pub depth: Arc>, + + /// The statement to use to commit the transaction. + pub commit_stmt: String, + + /// The statement to use to rollback the transaction. + pub rollback_stmt: String, } /// A default representation of an SQL database transaction. If not commited, a @@ -36,6 +47,9 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + pub depth: Arc>, + pub commit_stmt: String, + pub rollback_stmt: String, } impl<'a> DefaultTransaction<'a> { @@ -44,7 +58,12 @@ impl<'a> DefaultTransaction<'a> { begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let this = Self { + inner, + depth: tx_opts.depth, + commit_stmt: tx_opts.commit_stmt, + rollback_stmt: tx_opts.rollback_stmt, + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -70,17 +89,29 @@ impl<'a> DefaultTransaction<'a> { #[async_trait] impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. - async fn commit(&self) -> crate::Result<()> { + async fn commit(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("COMMIT").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.commit_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } /// Rolls back the changes to the database. - async fn rollback(&self) -> crate::Result<()> { + async fn rollback(&mut self) -> crate::Result<()> { decrement_gauge!("prisma_client_queries_active", 1.0); - self.inner.raw_cmd("ROLLBACK").await?; + + let mut depth_guard = self.depth.lock().await; + + self.inner.raw_cmd(&self.rollback_stmt).await?; + + // Modify the depth value through the MutexGuard + *depth_guard -= 1; Ok(()) } @@ -190,10 +221,19 @@ impl FromStr for IsolationLevel { } } impl TransactionOptions { - pub fn new(isolation_level: Option, isolation_first: bool) -> Self { + pub fn new( + isolation_level: Option, + isolation_first: bool, + depth: Arc>, + commit_stmt: String, + rollback_stmt: String, + ) -> Self { Self { isolation_level, isolation_first, + depth, + commit_stmt, + rollback_stmt, } } } diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 4c4152923377..aec229b744dc 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -500,7 +500,10 @@ impl Quaint { } }; - Ok(PooledConnection { inner }) + Ok(PooledConnection { + inner, + transaction_depth: Arc::new(futures::lock::Mutex::new(0)) + }) } /// Info about the connection and underlying database. diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..27367961cbe5 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -11,11 +11,14 @@ use crate::{ }; use async_trait::async_trait; use mobc::{Connection as MobcPooled, Manager}; +use futures::lock::Mutex; +use std::sync::Arc; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). pub struct PooledConnection { pub(crate) inner: MobcPooled, + pub transaction_depth: Arc>, } impl_default_TransactionCapable!(PooledConnection); @@ -62,8 +65,16 @@ impl Queryable for PooledConnection { self.inner.server_reset_query(tx).await } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..da173321ff51 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -8,6 +8,7 @@ use crate::{ }; use async_trait::async_trait; use std::{fmt, sync::Arc}; +use futures::lock::Mutex; #[cfg(feature = "sqlite")] use std::convert::TryFrom; @@ -17,6 +18,7 @@ use std::convert::TryFrom; pub struct Quaint { inner: Arc, connection_info: Arc, + transaction_depth: Arc>, } impl fmt::Debug for Quaint { @@ -163,7 +165,7 @@ impl Quaint { let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); Self::log_start(&connection_info); - Ok(Self { inner, connection_info }) + Ok(Self { inner, connection_info, transaction_depth: Arc::new(Mutex::new(0)) }) } #[cfg(feature = "sqlite")] @@ -174,6 +176,7 @@ impl Quaint { connection_info: Arc::new(ConnectionInfo::InMemorySqlite { db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), }), + transaction_depth: Arc::new(Mutex::new(0)), }) } @@ -228,8 +231,16 @@ impl Queryable for Quaint { self.inner.is_healthy() } - fn begin_statement(&self) -> &'static str { - self.inner.begin_statement() + async fn begin_statement(&self, depth: i32) -> String { + self.inner.begin_statement(depth).await + } + + async fn commit_statement(&self, depth: i32) -> String { + self.inner.commit_statement(depth).await + } + + async fn rollback_statement(&self, depth: i32) -> String { + self.inner.rollback_statement(depth).await } async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a9601..cf471fbf7330 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -64,7 +64,7 @@ async fn select_star_from(api: &mut dyn TestApi) -> crate::Result<()> { async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("value int").await?; - let tx = api.conn().start_transaction(None).await?; + let mut tx = api.conn().start_transaction(None).await?; let insert = Insert::single_into(&table).value("value", 10); let rows_affected = tx.execute(insert.into()).await?; @@ -75,6 +75,20 @@ async fn transactions(api: &mut dyn TestApi) -> crate::Result<()> { assert_eq!(Value::int32(10), res[0]); + // Check that nested transactions are also rolled back, even at multiple levels deep + let mut tx_inner = api.conn().start_transaction(None).await?; + let inner_insert1 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected1 = tx.execute(inner_insert1.into()).await?; + assert_eq!(1, inner_rows_affected1); + + let mut tx_inner2 = api.conn().start_transaction(None).await?; + let inner_insert2 = Insert::single_into(&table).value("value", 20); + let inner_rows_affected2 = tx.execute(inner_insert2.into()).await?; + assert_eq!(1, inner_rows_affected2); + tx_inner2.commit().await?; + + tx_inner.commit().await?; + tx.rollback().await?; let select = Select::from_table(&table).column("value"); diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d3..67334858576e 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -456,7 +456,7 @@ async fn concurrent_transaction_conflict(api: &mut dyn TestApi) -> crate::Result let conn1 = api.create_additional_connection().await?; let conn2 = api.create_additional_connection().await?; - let tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; + let mut tx1 = conn1.start_transaction(Some(IsolationLevel::Serializable)).await?; let tx2 = conn2.start_transaction(Some(IsolationLevel::Serializable)).await?; tx1.query(Select::from_table(&table).into()).await?;