refactor: prepare for altering prepared execution for Update (typelev…
ulfryk committed Dec 18, 2024
1 parent 1423839 commit 5a89748
Showing 3 changed files with 47 additions and 37 deletions.
modules/core/src/main/scala/doobie/hi/connection.scala (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ object connection {
loggingInfo
)

def executionWithResultSet[A](
def executeWithResultSet[A](
prepared: PreparedExecution[A],
loggingInfo: LoggingInfo
): ConnectionIO[A] = executeWithResultSet(
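For reference, the renamed executeWithResultSet overload takes a PreparedExecution plus LoggingInfo, as the signature above shows: it prepares a statement, sets parameters, executes, and processes the ResultSet, with the logging metadata attached. A minimal sketch of calling it directly (hypothetical SQL and label; the field names create/prep/exec/process are taken from the update.scala hunks below; the IFC/IFPS/IHRS constructors are assumed from doobie's free and hi modules):

import doobie.*
import doobie.free.{connection as IFC, preparedstatement as IFPS}
import doobie.hi.{connection as IHC, resultset as IHRS}
import doobie.hi.connection.PreparedExecution
import doobie.util.log.{LoggingInfo, Parameters}

// Prepare the statement, execute it, and read a single Int from the
// result set, with the logging metadata attached. Sketch only.
val countUsers: ConnectionIO[Int] =
  IHC.executeWithResultSet(
    PreparedExecution(
      create  = IFC.prepareStatement("SELECT count(*) FROM users"),
      prep    = IFPS.unit,            // no parameters to set
      exec    = IFPS.executeQuery,
      process = IHRS.getUnique[Int]
    ),
    LoggingInfo("SELECT count(*) FROM users", Parameters.NonBatch(Nil), "count-users")
  )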
modules/core/src/main/scala/doobie/util/query.scala (4 changes: 2 additions & 2 deletions)
@@ -192,14 +192,14 @@ object query {
toConnectionIOAlteringExecution(a, IHRS.nel[B], fn)

private def toConnectionIO[C](a: A, rsio: ResultSetIO[C]): ConnectionIO[C] =
IHC.executionWithResultSet(preparedExecution(sql, a, rsio), mkLoggingInfo(a))
IHC.executeWithResultSet(preparedExecution(sql, a, rsio), mkLoggingInfo(a))

private def toConnectionIOAlteringExecution[C](
a: A,
rsio: ResultSetIO[C],
fn: PreparedExecution[C] => PreparedExecution[C]
): ConnectionIO[C] =
IHC.executionWithResultSet(fn(preparedExecution(sql, a, rsio)), mkLoggingInfo(a))
IHC.executeWithResultSet(fn(preparedExecution(sql, a, rsio)), mkLoggingInfo(a))

private def preparedExecution[C](sql: String, a: A, rsio: ResultSetIO[C]): PreparedExecution[C] =
PreparedExecution(
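The fn: PreparedExecution[C] => PreparedExecution[C] hook above is what lets callers adjust the prepared execution before it runs. One possible alteration, sketched under the assumption that PreparedExecution is a case class (its named-argument construction suggests so) and that IFPS.setFetchSize wraps JDBC's setFetchSize:

import cats.syntax.all.*
import doobie.free.{preparedstatement as IFPS}
import doobie.hi.connection.PreparedExecution

// Run the original prep step, then raise the JDBC fetch size.
// A caller could pass this as the `fn` argument shown above.
def withFetchSize[C](n: Int): PreparedExecution[C] => PreparedExecution[C] =
  pe => pe.copy(prep = pe.prep *> IFPS.setFetchSize(n))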
modules/core/src/main/scala/doobie/util/update.scala (78 changes: 44 additions & 34 deletions)
@@ -12,6 +12,7 @@ import doobie.implicits.*
import doobie.util.analysis.Analysis
import doobie.util.pos.Pos
import doobie.free.{connection as IFC, preparedstatement as IFPS}
import doobie.hi.connection.{PreparedExecution, PreparedExecutionWithoutProcessStep}
import doobie.hi.{connection as IHC, preparedstatement as IHPS, resultset as IHRS}
import doobie.util.fragment.Fragment
import doobie.util.log.{LoggingInfo, Parameters}
@@ -81,18 +81,22 @@ object update {
/** Construct a program to execute the update and yield a count of affected rows, given the writable argument `a`.
* @group Execution
*/
def run(a: A): ConnectionIO[Int] = {
IHC.executeWithoutResultSet(
IFC.prepareStatement(sql),
IHPS.set(a),
IFPS.executeUpdate,
LoggingInfo(
sql,
Parameters.NonBatch(write.toList(a)),
label
)
def run(a: A): ConnectionIO[Int] =
IHC.executeWithoutResultSet(prepareExecutionForRun(a), loggingForRun(a))

private def prepareExecutionForRun(a: A): PreparedExecutionWithoutProcessStep[Int] =
PreparedExecutionWithoutProcessStep(
create = IFC.prepareStatement(sql),
prep = IHPS.set(a),
exec = IFPS.executeUpdate
)

private def loggingForRun(a: A): LoggingInfo =
LoggingInfo(
sql,
Parameters.NonBatch(write.toList(a)),
label
)
}
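Per the commit message, extracting prepareExecutionForRun and loggingForRun prepares the ground for a run variant that can alter the prepared execution, mirroring toConnectionIOAlteringExecution in query.scala. A hypothetical follow-up (not part of this commit) could look like:

// Hypothetical: let the caller tweak the prepared execution before running.
def runAlteringExecution(
  a: A,
  fn: PreparedExecutionWithoutProcessStep[Int] => PreparedExecutionWithoutProcessStep[Int]
): ConnectionIO[Int] =
  IHC.executeWithoutResultSet(fn(prepareExecutionForRun(a)), loggingForRun(a))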

/** Add many sets of parameters and execute as a batch update, returning total rows updated. Note that when an error
* occurred while executing the batch, your JDBC driver may decide to continue executing the rest of the batch
@@ -103,15 +108,20 @@
* @group Execution
*/
def updateMany[F[_]: Foldable](fa: F[A]): ConnectionIO[Int] =
IHC.executeWithoutResultSet(
IHC.executeWithoutResultSet(prepareExecutionForUpdateMany(fa), loggingInfoForUpdateMany(fa))

private def prepareExecutionForUpdateMany[F[_]: Foldable](fa: F[A]): PreparedExecutionWithoutProcessStep[Int] =
PreparedExecutionWithoutProcessStep(
create = IFC.prepareStatement(sql),
prep = fa.foldMap(a => IHPS.set(a) *> IFPS.addBatch),
exec = IFPS.executeBatch.map(updateCounts => updateCounts.foldLeft(0)((acc, n) => acc + (n.max(0)))),
loggingInfo = LoggingInfo(
sql,
Parameters.Batch(() => fa.toList.map(write.toList)),
label
)
exec = IFPS.executeBatch.map(updateCounts => updateCounts.foldLeft(0)((acc, n) => acc + n.max(0)))
)

private def loggingInfoForUpdateMany[F[_]: Foldable](fa: F[A]) =
LoggingInfo(
sql,
Parameters.Batch(() => fa.toList.map(write.toList)),
label
)

/** Construct a stream that performs a batch update as with `updateMany`, yielding generated keys of readable type
@@ -130,11 +140,7 @@
prep = IHPS.addBatches(as),
exec = IFPS.executeBatch *> IFPS.getGeneratedKeys,
chunkSize = chunkSize,
loggingInfo = LoggingInfo(
sql,
Parameters.Batch(() => as.toList.map(Write[A].toList)),
label
)
loggingInfo = loggingInfoForUpdateMany(as)
)
}

@@ -157,11 +163,7 @@
prep = IHPS.set(a),
exec = IFPS.executeUpdate *> IFPS.getGeneratedKeys,
chunkSize = chunkSize,
loggingInfo = LoggingInfo(
sql = sql,
params = Parameters.NonBatch(Write[A].toList(a)),
label = label
)
loggingInfo = loggingInfoForUpdateWithGeneratedKeys(a)
)

/** Construct a program that performs the update, yielding a single set of generated keys of readable type `K`,
@@ -171,15 +173,23 @@
*/
def withUniqueGeneratedKeys[K: Read](columns: String*)(a: A): ConnectionIO[K] =
IHC.executeWithResultSet(
prepareExecutionForWithUniqueGeneratedKeys(columns*)(a),
loggingInfoForUpdateWithGeneratedKeys(a)
)

private def prepareExecutionForWithUniqueGeneratedKeys[K: Read](columns: String*)(a: A): PreparedExecution[K] =
PreparedExecution(
create = IFC.prepareStatement(sql, columns.toArray),
prep = IHPS.set(a),
exec = IFPS.executeUpdate *> IFPS.getGeneratedKeys,
process = IHRS.getUnique,
loggingInfo = LoggingInfo(
sql,
Parameters.NonBatch(write.toList(a)),
label
)
process = IHRS.getUnique
)

private def loggingInfoForUpdateWithGeneratedKeys(a: A) =
LoggingInfo(
sql,
Parameters.NonBatch(write.toList(a)),
label
)

/** Update is a contravariant functor.
