diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala index 0aff32f357..cb4dfa3e1b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaLog.scala @@ -221,9 +221,9 @@ class DeltaLog private( catalogTableOpt: Option[CatalogTable], snapshotOpt: Option[Snapshot] = None)( thunk: OptimisticTransaction => T): T = { + val txn = startTransaction(catalogTableOpt, snapshotOpt) + OptimisticTransaction.setActive(txn) try { - val txn = startTransaction(catalogTableOpt, snapshotOpt) - OptimisticTransaction.setActive(txn) thunk(txn) } finally { OptimisticTransaction.clearActive() @@ -233,9 +233,9 @@ class DeltaLog private( /** Legacy/compat overload that does not require catalog table information. Avoid prod use. */ @deprecated("Please use the CatalogTable overload instead", "3.0") def withNewTransaction[T](thunk: OptimisticTransaction => T): T = { + val txn = startTransaction() + OptimisticTransaction.setActive(txn) try { - val txn = startTransaction() - OptimisticTransaction.setActive(txn) thunk(txn) } finally { OptimisticTransaction.clearActive() diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala index 77d7d2d804..4e094fc7cd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala @@ -216,10 +216,14 @@ object OptimisticTransaction { * `OptimisticTransaction.withNewTransaction`. Use that to create and set active txns. */ private[delta] def setActive(txn: OptimisticTransaction): Unit = { - if (active.get != null) { - throw DeltaErrors.activeTransactionAlreadySet() + getActive() match { + case Some(activeTxn) => + if (!(activeTxn eq txn)) { + throw DeltaErrors.activeTransactionAlreadySet() + } + case _ => + active.set(txn) } - active.set(txn) } /** diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaWithNewTransactionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaWithNewTransactionSuite.scala index 79fa8a3fb6..981295ba8c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaWithNewTransactionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaWithNewTransactionSuite.scala @@ -263,16 +263,38 @@ trait DeltaWithNewTransactionSuiteBase extends QueryTest withTempDir { dir => val log = DeltaLog.forTable(spark, dir.getCanonicalPath) log.withNewTransaction { txn => - - require(OptimisticTransaction.getActive().nonEmpty) + assert(OptimisticTransaction.getActive() === Some(txn)) intercept[IllegalStateException] { - OptimisticTransaction.setActive(txn) + log.withNewTransaction { txn2 => } + } + assert(OptimisticTransaction.getActive() === Some(txn)) + } + assert(OptimisticTransaction.getActive().isEmpty) + } + } + + test("withActiveTxn idempotency") { + withTempDir { dir => + val log = DeltaLog.forTable(spark, dir.getCanonicalPath) + val txn = log.startTransaction() + assert(OptimisticTransaction.getActive().isEmpty) + OptimisticTransaction.withActive(txn) { + assert(OptimisticTransaction.getActive() === Some(txn)) + OptimisticTransaction.withActive(txn) { + assert(OptimisticTransaction.getActive() === Some(txn)) } + assert(OptimisticTransaction.getActive() === Some(txn)) + val txn2 = log.startTransaction() intercept[IllegalStateException] { - log.withNewTransaction { txn2 => } + OptimisticTransaction.withActive(txn2) { } + } + intercept[IllegalStateException] { + OptimisticTransaction.setActive(txn2) } + assert(OptimisticTransaction.getActive() === Some(txn)) } + assert(OptimisticTransaction.getActive().isEmpty) } }