Skip to content

Commit

Permalink
Handle connections that are not accessed before max lifetime
Browse files Browse the repository at this point in the history
Ensure that a connection older than the maximum lifetime is never given
out by the pool.
  • Loading branch information
guymers authored Oct 3, 2023
1 parent 0f4cf95 commit e7d0247
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
31 changes: 23 additions & 8 deletions modules/zio/src/main/scala/zoobie/ConnectionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ object ConnectionPool {
} yield {
new ConnectionPool {
override def get(implicit trace: Trace) = for {
tuple <- _get.withEarlyRelease
(close, i) = tuple
now <- zio.Clock.nanoTime
age = (now - i.acquired).nanoseconds
c <- {
if (age > config.maxConnectionLifetime) {
// been idle and missed invalidation, close and let the finalizer invalidate
close *> _get
} else ZIO.succeed(i)
}
} yield c.connection

private def _get(implicit trace: Trace) = for {
atQueueSize <- numQueuedRef.modify { i =>
if (i >= config.queueSize) (true, config.queueSize) else (false, i + 1)
}.uninterruptible
Expand All @@ -83,20 +96,22 @@ object ConnectionPool {
_ <- ZIO.acquireRelease(inUse.increment)(_ => inUse.decrement)

_ <- ZIO.addFinalizerExit {
case Exit.Success(_) => for {
now <- zio.Clock.nanoTime
age = (now - c.acquired).nanoseconds
maxLifetimeJitter <- zio.Random.nextDoubleBetween(0.9, 1.1)
maxLifetime = (config.maxConnectionLifetime.toNanos * maxLifetimeJitter).toLong.nanoseconds
_ <- invalidate(c).when(age > maxLifetime)
} yield ()
case Exit.Success(_) =>
invalidate(c).whenZIO {
for {
now <- zio.Clock.nanoTime
age = (now - c.acquired).nanoseconds
maxLifetimeJitter <- zio.Random.nextDoubleBetween(0.89, 0.99)
maxLifetime = (config.maxConnectionLifetime.toNanos * maxLifetimeJitter).toLong.nanoseconds
} yield age > maxLifetime
}

case Exit.Failure(_) =>
invalidate(c).unlessZIO {
c.isValid(config.validationTimeout).catchAll(_ => ZIO.succeed(false))
}
}
} yield c.connection
} yield c
}
}
}
Expand Down
12 changes: 7 additions & 5 deletions modules/zio/src/main/scala/zoobie/Transactor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ sealed abstract class Transactor { self =>
* Execute the given [[ConnectionIO]] on a connection using the strategy.
*/
def run[A](io: ConnectionIO[A])(implicit trace: Trace): ZIO[Any, DatabaseError, A] = ZIO.scoped {
connection.flatMap { conn =>
translate(conn) { strategy.resource.use(_ => io) }
.mapError(DatabaseError(_))
}
connection.flatMap(interpret(io)(_))
}

/**
Expand All @@ -44,7 +41,7 @@ sealed abstract class Transactor { self =>
import zio.stream.interop.fs2z.*

s.translate(new (ConnectionIO ~> Task) {
override def apply[T](io: ConnectionIO[T]): ZIO[Any, DatabaseError, T] = run(io)
override def apply[T](io: ConnectionIO[T]) = run(io)
}).toZStream(chunkSize).mapError(DatabaseError(_))
}

Expand All @@ -65,6 +62,11 @@ sealed abstract class Transactor { self =>
}
}

def interpret[A](io: ConnectionIO[A])(c: Connection): ZIO[Any, DatabaseError, A] = {
translate(c) { strategy.resource.use(_ => io) }
.mapError(DatabaseError(_))
}

def translate(c: Connection): ConnectionIO ~> Task = {
implicit val monad: Monad[Task] = Transactor.sync

Expand Down
4 changes: 2 additions & 2 deletions modules/zio/src/test/scala/zoobie/ConnectionPoolSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ object ConnectionPoolSpec extends ZIOSpecDefault {

override val spec = suite("ConnectionPool")(
test("connection fails to acquire") {
// FIXME this test only passes due to being lucky with the connection order
for {
_ <- ZIO.unit
ref <- Ref.make(1)
create = ref.modify { i =>
val conn = if (i % 2 == 0) {
Expand Down Expand Up @@ -150,7 +150,7 @@ object ConnectionPoolSpec extends ZIOSpecDefault {
_ <- ZIO.foreachDiscard((1 to config.size).toList)(_ => ZIO.scoped(pool.get))
createdInitial <- createdRef.get

_ <- TestClock.adjust((config.maxConnectionLifetime.toNanos * 1.11).toLong.nanos) // jittered 0.9-1.1
_ <- TestClock.adjust(config.maxConnectionLifetime)

_ <- ZIO.foreachDiscard((1 to config.size).toList)(_ => ZIO.scoped(pool.get))
createdRefreshed <- createdRef.get
Expand Down

0 comments on commit e7d0247

Please sign in to comment.