From 3860300de2004a727bd363f2b6b9774a0595f238 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 13 Jul 2024 22:49:29 +0200 Subject: [PATCH] WIP: Waker Drop fix --- Cargo.toml | 2 +- eyeball/Cargo.toml | 2 +- eyeball/src/lock.rs | 12 ++++++++++++ eyeball/src/state.rs | 13 ++++++++++++- eyeball/src/subscriber.rs | 18 ++++++++++++++---- eyeball/src/subscriber/async_lock.rs | 18 +++++++++++++++--- 6 files changed, 55 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 552a153..af57324 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ keywords = ["async", "observable", "reactive"] assert_matches = "1.5.0" futures-core = "0.3.26" futures-util = { version = "0.3.26", default-features = false } -readlock = "0.1.5" +readlock = "0.1.8" stream_assert = "0.1.0" tokio = { version = "1.25.0", features = ["sync"] } tokio-util = "0.7.8" diff --git a/eyeball/Cargo.toml b/eyeball/Cargo.toml index f09e5c6..9d63dbe 100644 --- a/eyeball/Cargo.toml +++ b/eyeball/Cargo.toml @@ -16,7 +16,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] futures-core.workspace = true readlock.workspace = true -readlock-tokio = { version = "0.1.1", optional = true } +readlock-tokio = { version = "0.1.3", optional = true } slab = "0.4.9" tracing = { workspace = true, optional = true } tokio = { workspace = true, optional = true } diff --git a/eyeball/src/lock.rs b/eyeball/src/lock.rs index 4b2d31b..7e9b668 100644 --- a/eyeball/src/lock.rs +++ b/eyeball/src/lock.rs @@ -25,6 +25,8 @@ pub trait Lock { fn new_shared(value: T) -> Self::Shared; fn shared_read_count(shared: &Self::Shared) -> usize; fn shared_into_inner(shared: Self::Shared) -> Arc>; + + fn drop_waker(state: &Self::SubscriberState, observed_version: u64, waker_key: usize); } /// Marker type for using a synchronous lock for the inner value. @@ -61,6 +63,12 @@ impl Lock for SyncLock { fn shared_into_inner(shared: Self::Shared) -> Arc> { Self::Shared::into_inner(shared) } + + fn drop_waker(state: &Self::SubscriberState, observed_version: u64, waker_key: usize) { + if let Ok(guard) = state.try_lock() { + guard.drop_waker(observed_version, waker_key); + } + } } /// Marker type for using an asynchronous lock for the inner value. @@ -99,4 +107,8 @@ impl Lock for AsyncLock { fn shared_into_inner(shared: Self::Shared) -> Arc> { Self::Shared::into_inner(shared) } + + fn drop_waker(state: &Self::SubscriberState, observed_version: u64, waker_key: usize) { + state.drop_waker(observed_version, waker_key); + } } diff --git a/eyeball/src/state.rs b/eyeball/src/state.rs index d0c21ee..01d1295 100644 --- a/eyeball/src/state.rs +++ b/eyeball/src/state.rs @@ -59,21 +59,32 @@ impl ObservableState { pub(crate) fn poll_update( &self, observed_version: &mut u64, + waker_key: &mut Option, cx: &Context<'_>, ) -> Poll> { let mut metadata = self.metadata.write().unwrap(); if metadata.version == 0 { + *waker_key = None; Poll::Ready(None) } else if *observed_version < metadata.version { + *waker_key = None; *observed_version = metadata.version; Poll::Ready(Some(())) } else { - metadata.wakers.insert(cx.waker().clone()); + *waker_key = Some(metadata.wakers.insert(cx.waker().clone())); Poll::Pending } } + pub(crate) fn drop_waker(&self, observed_version: u64, waker_key: usize) { + let mut metadata = self.metadata.write().unwrap(); + if metadata.version == observed_version { + let _res = metadata.wakers.try_remove(waker_key); + debug_assert!(_res.is_some()); + } + } + pub(crate) fn set(&mut self, value: T) -> T { let result = mem::replace(&mut self.value, value); self.incr_version_and_wake(); diff --git a/eyeball/src/subscriber.rs b/eyeball/src/subscriber.rs index ff2637b..7703220 100644 --- a/eyeball/src/subscriber.rs +++ b/eyeball/src/subscriber.rs @@ -22,11 +22,13 @@ pub(crate) mod async_lock; pub struct Subscriber { state: L::SubscriberState, observed_version: u64, + // TODO: NonMaxUsize would be nice + waker_key: Option, } impl Subscriber { pub(crate) fn new(state: readlock::SharedReadLock>, version: u64) -> Self { - Self { state, observed_version: version } + Self { state, observed_version: version, waker_key: None } } /// Wait for an update and get a clone of the updated value. @@ -123,7 +125,7 @@ impl Subscriber { fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll>> { let state = self.state.lock(); state - .poll_update(&mut self.observed_version, cx) + .poll_update(&mut self.observed_version, &mut self.waker_key, cx) .map(|ready| ready.map(|_| ObservableReadGuard::new(state))) } } @@ -153,7 +155,7 @@ impl Subscriber { where L::SubscriberState: Clone, { - Self { state: self.state.clone(), observed_version: 0 } + Self { state: self.state.clone(), observed_version: 0, waker_key: None } } } @@ -171,7 +173,7 @@ where L::SubscriberState: Clone, { fn clone(&self) -> Self { - Self { state: self.state.clone(), observed_version: self.observed_version } + Self { state: self.state.clone(), observed_version: self.observed_version, waker_key: None } } } @@ -195,6 +197,14 @@ impl Stream for Subscriber { } } +impl Drop for Subscriber { + fn drop(&mut self) { + if let Some(waker_key) = self.waker_key { + L::drop_waker(&self.state, self.observed_version, waker_key); + } + } +} + /// Future returned by [`Subscriber::next`]. #[must_use] #[allow(missing_debug_implementations)] diff --git a/eyeball/src/subscriber/async_lock.rs b/eyeball/src/subscriber/async_lock.rs index 1380646..c964953 100644 --- a/eyeball/src/subscriber/async_lock.rs +++ b/eyeball/src/subscriber/async_lock.rs @@ -17,6 +17,14 @@ pub struct AsyncSubscriberState { get_lock: ReusableBoxFuture<'static, OwnedSharedReadGuard>>, } +impl AsyncSubscriberState { + pub(crate) fn drop_waker(&self, observed_version: u64, waker_key: usize) { + if let Ok(guard) = self.inner.try_lock() { + guard.drop_waker(observed_version, waker_key); + } + } +} + impl Clone for AsyncSubscriberState { fn clone(&self) -> Self { Self { @@ -35,7 +43,11 @@ impl fmt::Debug for AsyncSubscriberState { impl Subscriber { pub(crate) fn new_async(inner: SharedReadLock>, version: u64) -> Self { let get_lock = ReusableBoxFuture::new(inner.clone().lock_owned()); - Self { state: AsyncSubscriberState { inner, get_lock }, observed_version: version } + Self { + state: AsyncSubscriberState { inner, get_lock }, + observed_version: version, + waker_key: None, + } } /// Wait for an update and get a clone of the updated value. @@ -132,7 +144,7 @@ impl Subscriber { fn poll_update(&mut self, cx: &mut Context<'_>) -> Poll> { let state = ready!(self.state.get_lock.poll(cx)); self.state.get_lock.set(self.state.inner.clone().lock_owned()); - state.poll_update(&mut self.observed_version, cx) + state.poll_update(&mut self.observed_version, &mut self.waker_key, cx) } fn poll_next_nopin(&mut self, cx: &mut Context<'_>) -> Poll> @@ -142,7 +154,7 @@ impl Subscriber { let state = ready!(self.state.get_lock.poll(cx)); self.state.get_lock.set(self.state.inner.clone().lock_owned()); state - .poll_update(&mut self.observed_version, cx) + .poll_update(&mut self.observed_version, &mut self.waker_key, cx) .map(|ready| ready.map(|_| state.get().clone())) } }