Skip to content

Commit

Permalink
WIP: Waker Drop fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Jul 13, 2024
1 parent 73ccb62 commit 3860300
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ keywords = ["async", "observable", "reactive"]
assert_matches = "1.5.0"
futures-core = "0.3.26"
futures-util = { version = "0.3.26", default-features = false }
readlock = "0.1.5"
readlock = "0.1.8"
stream_assert = "0.1.0"
tokio = { version = "1.25.0", features = ["sync"] }
tokio-util = "0.7.8"
Expand Down
2 changes: 1 addition & 1 deletion eyeball/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rustdoc-args = ["--cfg", "docsrs"]
[dependencies]
futures-core.workspace = true
readlock.workspace = true
readlock-tokio = { version = "0.1.1", optional = true }
readlock-tokio = { version = "0.1.3", optional = true }
slab = "0.4.9"
tracing = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
Expand Down
12 changes: 12 additions & 0 deletions eyeball/src/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub trait Lock {
fn new_shared<T>(value: T) -> Self::Shared<T>;
fn shared_read_count<T>(shared: &Self::Shared<T>) -> usize;
fn shared_into_inner<T>(shared: Self::Shared<T>) -> Arc<Self::RwLock<T>>;

fn drop_waker<S>(state: &Self::SubscriberState<S>, observed_version: u64, waker_key: usize);
}

/// Marker type for using a synchronous lock for the inner value.
Expand Down Expand Up @@ -61,6 +63,12 @@ impl Lock for SyncLock {
fn shared_into_inner<T>(shared: Self::Shared<T>) -> Arc<Self::RwLock<T>> {
Self::Shared::into_inner(shared)
}

fn drop_waker<S>(state: &Self::SubscriberState<S>, observed_version: u64, waker_key: usize) {
if let Ok(guard) = state.try_lock() {
guard.drop_waker(observed_version, waker_key);
}
}
}

/// Marker type for using an asynchronous lock for the inner value.
Expand Down Expand Up @@ -99,4 +107,8 @@ impl Lock for AsyncLock {
fn shared_into_inner<T>(shared: Self::Shared<T>) -> Arc<Self::RwLock<T>> {
Self::Shared::into_inner(shared)
}

fn drop_waker<S>(state: &Self::SubscriberState<S>, observed_version: u64, waker_key: usize) {
state.drop_waker(observed_version, waker_key);
}
}
13 changes: 12 additions & 1 deletion eyeball/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,32 @@ impl<T> ObservableState<T> {
pub(crate) fn poll_update(
&self,
observed_version: &mut u64,
waker_key: &mut Option<usize>,
cx: &Context<'_>,
) -> Poll<Option<()>> {
let mut metadata = self.metadata.write().unwrap();

if metadata.version == 0 {
*waker_key = None;
Poll::Ready(None)
} else if *observed_version < metadata.version {
*waker_key = None;
*observed_version = metadata.version;
Poll::Ready(Some(()))
} else {
metadata.wakers.insert(cx.waker().clone());
*waker_key = Some(metadata.wakers.insert(cx.waker().clone()));
Poll::Pending
}
}

pub(crate) fn drop_waker(&self, observed_version: u64, waker_key: usize) {
let mut metadata = self.metadata.write().unwrap();
if metadata.version == observed_version {
let _res = metadata.wakers.try_remove(waker_key);
debug_assert!(_res.is_some());
}
}

pub(crate) fn set(&mut self, value: T) -> T {
let result = mem::replace(&mut self.value, value);
self.incr_version_and_wake();
Expand Down
18 changes: 14 additions & 4 deletions eyeball/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ pub(crate) mod async_lock;
pub struct Subscriber<T, L: Lock = SyncLock> {
state: L::SubscriberState<T>,
observed_version: u64,
// TODO: NonMaxUsize would be nice
waker_key: Option<usize>,
}

impl<T> Subscriber<T> {
pub(crate) fn new(state: readlock::SharedReadLock<ObservableState<T>>, version: u64) -> Self {
Self { state, observed_version: version }
Self { state, observed_version: version, waker_key: None }
}

/// Wait for an update and get a clone of the updated value.
Expand Down Expand Up @@ -123,7 +125,7 @@ impl<T> Subscriber<T> {
fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
let state = self.state.lock();
state
.poll_update(&mut self.observed_version, cx)
.poll_update(&mut self.observed_version, &mut self.waker_key, cx)
.map(|ready| ready.map(|_| ObservableReadGuard::new(state)))
}
}
Expand Down Expand Up @@ -153,7 +155,7 @@ impl<T, L: Lock> Subscriber<T, L> {
where
L::SubscriberState<T>: Clone,
{
Self { state: self.state.clone(), observed_version: 0 }
Self { state: self.state.clone(), observed_version: 0, waker_key: None }
}
}

Expand All @@ -171,7 +173,7 @@ where
L::SubscriberState<T>: Clone,
{
fn clone(&self) -> Self {
Self { state: self.state.clone(), observed_version: self.observed_version }
Self { state: self.state.clone(), observed_version: self.observed_version, waker_key: None }
}
}

Expand All @@ -195,6 +197,14 @@ impl<T: Clone> Stream for Subscriber<T> {
}
}

impl<T, L: Lock> Drop for Subscriber<T, L> {
fn drop(&mut self) {
if let Some(waker_key) = self.waker_key {
L::drop_waker(&self.state, self.observed_version, waker_key);
}
}
}

/// Future returned by [`Subscriber::next`].
#[must_use]
#[allow(missing_debug_implementations)]
Expand Down
18 changes: 15 additions & 3 deletions eyeball/src/subscriber/async_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ pub struct AsyncSubscriberState<T> {
get_lock: ReusableBoxFuture<'static, OwnedSharedReadGuard<ObservableState<T>>>,
}

impl<T> AsyncSubscriberState<T> {
pub(crate) fn drop_waker(&self, observed_version: u64, waker_key: usize) {
if let Ok(guard) = self.inner.try_lock() {
guard.drop_waker(observed_version, waker_key);
}
}
}

impl<S: Send + Sync + 'static> Clone for AsyncSubscriberState<S> {
fn clone(&self) -> Self {
Self {
Expand All @@ -35,7 +43,11 @@ impl<S: fmt::Debug> fmt::Debug for AsyncSubscriberState<S> {
impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
pub(crate) fn new_async(inner: SharedReadLock<ObservableState<T>>, version: u64) -> Self {
let get_lock = ReusableBoxFuture::new(inner.clone().lock_owned());
Self { state: AsyncSubscriberState { inner, get_lock }, observed_version: version }
Self {
state: AsyncSubscriberState { inner, get_lock },
observed_version: version,
waker_key: None,
}
}

/// Wait for an update and get a clone of the updated value.
Expand Down Expand Up @@ -132,7 +144,7 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
fn poll_update(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());
state.poll_update(&mut self.observed_version, cx)
state.poll_update(&mut self.observed_version, &mut self.waker_key, cx)
}

fn poll_next_nopin(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>>
Expand All @@ -142,7 +154,7 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());
state
.poll_update(&mut self.observed_version, cx)
.poll_update(&mut self.observed_version, &mut self.waker_key, cx)
.map(|ready| ready.map(|_| state.get().clone()))
}
}
Expand Down

0 comments on commit 3860300

Please sign in to comment.