diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index be2ef52..90ff63e 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -91,6 +91,7 @@ where let mut wait_time_start = None; let future = async { + let _guard = self.inner.request(); loop { let (conn, approvals) = self.inner.pop(); self.spawn_replenishing_approvals(approvals); @@ -158,7 +159,7 @@ where } let approvals = locked.dropped(1, &self.inner.statics); self.spawn_replenishing_approvals(approvals); - self.inner.notify.notify_waiters(); + self.inner.notify.notify_one(); } } } diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 155e21a..855dd1c 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -22,6 +22,27 @@ where pub(crate) statistics: AtomicStatistics, } +pub(crate) struct GetGuard { + inner: Arc>, +} + +impl GetGuard { + fn new(inner: Arc>) -> Self { + { + let mut locked = inner.internals.lock(); + locked.inflight_gets += 1; + } + GetGuard { inner } + } +} + +impl Drop for GetGuard { + fn drop(&mut self) { + let mut locked = self.inner.internals.lock(); + locked.inflight_gets -= 1; + } +} + impl SharedPool where M: ManageConnection + Send, @@ -41,12 +62,19 @@ where let conn = locked.conns.pop_front().map(|idle| idle.conn); let approvals = match &conn { Some(_) => locked.wanted(&self.statics), - None => locked.approvals(&self.statics, 1), + None => { + let approvals = min(1, locked.inflight_gets.saturating_sub(locked.pending_conns)); + locked.approvals(&self.statics, approvals) + } }; (conn, approvals) } + pub(crate) fn request(self: &Arc) -> GetGuard { + GetGuard::new(self.clone()) + } + pub(crate) fn try_put(self: &Arc, conn: M::Connection) -> Result<(), M::Connection> { let mut locked = self.internals.lock(); let mut approvals = locked.approvals(&self.statics, 1); @@ -81,6 +109,7 @@ where conns: VecDeque>, num_conns: u32, pending_conns: u32, + inflight_gets: u32, } impl PoolInternals @@ -202,6 +231,7 @@ where conns: VecDeque::new(), num_conns: 0, pending_conns: 0, + inflight_gets: 0, } } } diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 0e1225e..0246501 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -1068,3 +1068,37 @@ async fn test_add_checks_broken_connections() { let res = pool.add(conn); assert!(matches!(res, Err(AddError::Broken(_)))); } + +#[tokio::test] +async fn test_reuse_on_drop() { + let pool = Pool::builder() + .min_idle(0) + .max_size(100) + .queue_strategy(QueueStrategy::Lifo) + .build(OkManager::::new()) + .await + .unwrap(); + + // The first get should + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection and grab it from the pool + let conn_0 = pool.get().await.expect("should connect"); + // Dropping the connection queues up a notify + drop(conn_0); + // The second get should + // 1) see the first connection in the pool and grab it + let _conn_1: PooledConnection> = + pool.get().await.expect("should connect"); + // The third get will + // 1) see nothing in the pool, + // 2) spawn a single replenishing approval, + // 3) get notified of the new connection, + // 4) see nothing in the pool, + // 5) _not_ spawn a single replenishing approval, + // 6) get notified of the new connection and grab it from the pool + let _conn_2: PooledConnection> = + pool.get().await.expect("should connect"); + + assert_eq!(pool.state().connections, 2); +}