diff --git a/src/win.rs b/src/win.rs index b0a2bce..64cfe0e 100644 --- a/src/win.rs +++ b/src/win.rs @@ -8,7 +8,6 @@ use std::ffi::c_void; use std::io::{Error, ErrorKind, Result}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use std::task::{Context, Poll}; use windows::Win32::Foundation::{BOOLEAN, HANDLE}; use windows::Win32::NetworkManagement::IpHelper::{ @@ -43,24 +42,25 @@ pub struct IfWatcher { queue: VecDeque, #[allow(unused)] notif: IpChangeNotification, - waker: Arc, - resync: Arc, + shared: Pin>, } impl IfWatcher { /// Create a watcher. pub fn new() -> Result { - let resync = Arc::new(AtomicBool::new(true)); - let waker = Arc::new(AtomicWaker::new()); + let shared = IfWatcherShared { + resync: true.into(), + waker: Default::default(), + }; + let shared = Box::pin(shared); Ok(Self { addrs: Default::default(), queue: Default::default(), - waker: waker.clone(), - resync: resync.clone(), - notif: IpChangeNotification::new(Box::new(move |_, _| { - resync.store(true, Ordering::Relaxed); - waker.wake(); - }))?, + // Safety: + // Self referential structure, `shared` will be dropped + // after `notif` + notif: unsafe { IpChangeNotification::new(shared.as_ref())? }, + shared, }) } @@ -96,10 +96,13 @@ impl IfWatcher { if let Some(event) = self.queue.pop_front() { return Poll::Ready(Ok(event)); } - if !self.resync.swap(false, Ordering::Relaxed) { - self.waker.register(cx.waker()); + + self.shared.waker.register(cx.waker()); + if !self.shared.resync.swap(false, Ordering::AcqRel) { return Poll::Pending; } + self.shared.waker.take(); + if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } @@ -137,10 +140,22 @@ fn ifaddr_to_ipnet(addr: IfAddr) -> IpNet { } } +#[derive(Debug)] +struct IfWatcherShared { + waker: AtomicWaker, + resync: AtomicBool, +} + +impl IpChangeCallback for IfWatcherShared { + fn callback(&self, _row: &MIB_IPINTERFACE_ROW, _notification_type: MIB_NOTIFICATION_TYPE) { + self.resync.store(true, Ordering::Release); + self.waker.wake(); + } +} + /// IP change notifications struct IpChangeNotification { handle: HANDLE, - callback: *mut IpChangeCallback, } impl std::fmt::Debug for IpChangeNotification { @@ -149,31 +164,37 @@ impl std::fmt::Debug for IpChangeNotification { } } -type IpChangeCallback = Box; - impl IpChangeNotification { /// Register for route change notifications - fn new(cb: IpChangeCallback) -> Result { - unsafe extern "system" fn global_callback( + /// + /// Safety: C must outlive the resulting Self + unsafe fn new(cb: Pin<&C>) -> Result + where + C: IpChangeCallback + Send + Sync, + { + unsafe extern "system" fn global_callback( caller_context: *const c_void, row: *const MIB_IPINTERFACE_ROW, notification_type: MIB_NOTIFICATION_TYPE, - ) { - (**(caller_context as *const IpChangeCallback))(&*row, notification_type) + ) where + C: IpChangeCallback + Send + Sync, + { + let caller_context = &*(caller_context as *const C); + caller_context.callback(&*row, notification_type) } let mut handle = HANDLE::default(); - let callback = Box::into_raw(Box::new(cb)); + let callback = cb.get_ref() as *const C; unsafe { NotifyIpInterfaceChange( AF_UNSPEC, - Some(global_callback), - Some(callback as _), + Some(global_callback::), + Some(callback as *const c_void), BOOLEAN(0), &mut handle as _, ) .map_err(|err| Error::new(ErrorKind::Other, err.to_string()))?; } - Ok(Self { callback, handle }) + Ok(Self { handle }) } } @@ -183,9 +204,12 @@ impl Drop for IpChangeNotification { if let Err(err) = CancelMibChangeNotify2(self.handle) { log::error!("error deregistering notification: {}", err); } - drop(Box::from_raw(self.callback)); } } } unsafe impl Send for IpChangeNotification {} + +trait IpChangeCallback { + fn callback(&self, row: &MIB_IPINTERFACE_ROW, notification_type: MIB_NOTIFICATION_TYPE); +}