diff --git a/CHANGES.md b/CHANGES.md index a3b5a84..f008035 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [3.2.0] - 2024-12-03 + +* Fix control queue handling + ## [3.1.0] - 2024-12-01 * Set "next_incoming_id" for Flow frame diff --git a/Cargo.toml b/Cargo.toml index 38c725e..67428d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-amqp" -version = "3.1.0" +version = "3.2.0" authors = ["ntex contributors "] description = "AMQP 1.0 Client/Server framework" documentation = "https://docs.rs/ntex-amqp" @@ -25,7 +25,8 @@ default = [] frame-trace = [] [dependencies] -ntex = "2" +ntex = "2.9" +ntex-util = "2.7" ntex-amqp-codec = "0.9" bitflags = "2" diff --git a/src/connection.rs b/src/connection.rs index 8f1a00e..33d0821 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -6,7 +6,7 @@ use ntex::util::{HashMap, PoolRef, Ready}; use crate::codec::protocol::{self as codec, Begin, Close, End, Error, Frame, Role}; use crate::codec::{AmqpCodec, AmqpFrame}; -use crate::dispatcher::ControlQueue; +use crate::control::ControlQueue; use crate::session::{Session, SessionInner, INITIAL_NEXT_OUTGOING_ID}; use crate::sndlink::{SenderLink, SenderLinkInner}; use crate::{cell::Cell, error::AmqpProtocolError, types::Action, Configuration}; diff --git a/src/control.rs b/src/control.rs index cfc1c47..254a05d 100644 --- a/src/control.rs +++ b/src/control.rs @@ -1,6 +1,6 @@ -use std::{fmt, io}; +use std::{cell::RefCell, collections::VecDeque, fmt, io}; -use ntex::util::Either; +use ntex::{task::LocalWaker, util::Either}; use ntex_amqp_codec::protocol; use crate::cell::Cell; @@ -72,3 +72,16 @@ impl ControlFrame { self.0.get_ref().session.clone().map(Session::new) } } + +#[derive(Default, Debug)] +pub(crate) struct ControlQueue { + pub(crate) pending: RefCell>, + pub(crate) waker: LocalWaker, +} + +impl ControlQueue { + pub(crate) fn enqueue_frame(&self, frame: ControlFrame) { + self.pending.borrow_mut().push_back(frame); + self.waker.wake(); + } +} diff --git a/src/dispatcher.rs b/src/dispatcher.rs index ecd2acc..09a360d 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -1,38 +1,29 @@ -use std::collections::VecDeque; use std::task::{Context, Poll}; use std::{cell, cmp, future::poll_fn, future::Future, marker, pin::Pin, rc::Rc}; use ntex::service::{Pipeline, PipelineBinding, PipelineCall, Service, ServiceCtx}; use ntex::time::{sleep, Millis, Sleep}; -use ntex::util::{ready, Either}; +use ntex::util::{ready, select, Either}; use ntex::{io::DispatchItem, rt::spawn, task::LocalWaker}; use crate::codec::{protocol::Frame, AmqpCodec, AmqpFrame}; use crate::error::{AmqpDispatcherError, AmqpProtocolError, Error}; use crate::{connection::Connection, types, ControlFrame, ControlFrameKind, ReceiverLink}; -#[derive(Default, Debug)] -pub(crate) struct ControlQueue { - pending: cell::RefCell>, - waker: LocalWaker, -} - -impl ControlQueue { - pub(crate) fn enqueue_frame(&self, frame: ControlFrame) { - self.pending.borrow_mut().push_back(frame); - self.waker.wake(); - } -} - /// Amqp server dispatcher service. -pub(crate) struct Dispatcher, Ctl: Service> { +pub(crate) struct Dispatcher, Ctl: Service>( + Rc>, +); + +struct DispatcherInner, Ctl: Service> { sink: Connection, service: PipelineBinding, ctl_service: PipelineBinding, - ctl_fut: cell::RefCell)>>, - ctl_queue: Rc, - expire: Sleep, + ctl_error: cell::Cell>, + ctl_error_waker: LocalWaker, + idle_sleep: Sleep, idle_timeout: Millis, + stopped: cell::Cell, } impl Dispatcher @@ -48,146 +39,25 @@ where idle_timeout: Millis, ) -> Self { let idle_timeout = Millis(cmp::min(idle_timeout.0 >> 1, 1000)); - let ctl_queue = sink.get_control_queue().clone(); - Dispatcher { + let inner = Rc::new(DispatcherInner { sink, idle_timeout, - ctl_queue, service: service.bind(), ctl_service: ctl_service.bind(), - ctl_fut: cell::RefCell::new(Vec::new()), - expire: sleep(idle_timeout), - } - } + ctl_error: cell::Cell::new(None), + ctl_error_waker: LocalWaker::default(), + idle_sleep: sleep(idle_timeout), + stopped: cell::Cell::new(false), + }); - fn call_control_service(&self, frame: ControlFrame) { - let fut = self.ctl_service.call(frame.clone()); - self.ctl_fut.borrow_mut().push((frame, fut)); - self.ctl_queue.waker.wake(); - } - - fn handle_idle_timeout(&self, cx: &mut Context<'_>) { - if self.idle_timeout.non_zero() && self.expire.poll_elapsed(cx).is_ready() { - log::trace!( - "{}: Send keep-alive ping, timeout: {:?} secs", - self.sink.tag(), - self.idle_timeout - ); - self.sink.post_frame(AmqpFrame::new(0, Frame::Empty)); - self.expire.reset(self.idle_timeout); - self.handle_idle_timeout(cx); - } + let disp = Dispatcher(inner); + disp.start_idle_timer(); + disp.start_control_queue(); + disp } - fn handle_control_fut(&self, cx: &mut Context<'_>) -> Result { - let mut ready = true; - let mut inner = self.ctl_fut.borrow_mut(); - - // process control frame - let mut idx = 0; - while inner.len() > idx { - let item = &mut inner[idx]; - let res = match Pin::new(&mut item.1).poll(cx) { - Poll::Pending => { - idx += 1; - ready = false; - continue; - } - Poll::Ready(res) => res, - }; - let (frame, _) = inner.swap_remove(idx); - match res { - Ok(_) => { - self.handle_control_frame(&frame, None)?; - } - Err(e) => { - self.handle_control_frame(&frame, Some(e.into()))?; - } - } - } - Ok(ready) - } - - fn handle_control_frame( - &self, - frame: &ControlFrame, - err: Option, - ) -> Result<(), AmqpDispatcherError> { - if let Some(err) = err { - match &frame.0.get_mut().kind { - ControlFrameKind::AttachReceiver(_, ref link) => { - let _ = link.close_with_error(err); - } - ControlFrameKind::AttachSender(ref frm, ref link) => { - frame - .session_cell() - .get_mut() - .detach_unconfirmed_sender_link(frm, link.inner.clone(), Some(err)); - } - ControlFrameKind::Flow(_, ref link) => { - let _ = link.close_with_error(err); - } - ControlFrameKind::LocalDetachSender(..) => {} - ControlFrameKind::LocalDetachReceiver(..) => {} - ControlFrameKind::RemoteDetachSender(_, ref link) => { - let _ = link.close_with_error(err); - } - ControlFrameKind::RemoteDetachReceiver(_, ref link) => { - let _ = link.close_with_error(err); - } - ControlFrameKind::ProtocolError(ref err) => { - self.sink.set_error(err.clone()); - return Err(err.clone().into()); - } - ControlFrameKind::Closed | ControlFrameKind::Disconnected(_) => { - self.sink.set_error(AmqpProtocolError::Disconnected); - } - ControlFrameKind::LocalSessionEnded(_) - | ControlFrameKind::RemoteSessionEnded(_) => (), - } - } else { - match frame.0.get_mut().kind { - ControlFrameKind::AttachReceiver(ref frm, ref link) => { - let link = link.clone(); - let frm = frm.clone(); - let fut = self - .service - .call(types::Message::Attached(frm.clone(), link.clone())); - let _ = ntex::rt::spawn(async move { - let result = fut.await; - if let Err(err) = result { - let _ = link.close_with_error(Error::from(err)).await; - } else { - link.confirm_receiver_link(&frm); - link.set_link_credit(50); - } - }); - } - ControlFrameKind::AttachSender(ref frm, ref link) => { - frame - .session_cell() - .get_mut() - .attach_remote_sender_link(frm, link.inner.clone()); - } - ControlFrameKind::Flow(ref frm, ref link) => { - frame.session_cell().get_mut().handle_flow(frm, Some(link)); - } - ControlFrameKind::ProtocolError(ref err) => { - self.sink.set_error(err.clone()); - return Err(err.clone().into()); - } - ControlFrameKind::Closed | ControlFrameKind::Disconnected(_) => { - self.sink.set_error(AmqpProtocolError::Disconnected); - } - ControlFrameKind::LocalDetachSender(..) - | ControlFrameKind::LocalDetachReceiver(..) - | ControlFrameKind::LocalSessionEnded(_) - | ControlFrameKind::RemoteDetachSender(..) - | ControlFrameKind::RemoteDetachReceiver(..) - | ControlFrameKind::RemoteSessionEnded(_) => (), - } - } - Ok(()) + fn call_control_service(&self, frame: ControlFrame) { + self.0.sink.get_control_queue().enqueue_frame(frame); } } @@ -202,50 +72,36 @@ where async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { poll_fn(|cx| { - self.ctl_queue.waker.register(cx.waker()); - - // idle timeout - self.handle_idle_timeout(cx); - - // process control frame - let mut control_fut_pending = !self.handle_control_fut(cx)?; + if let Some(err) = self.0.ctl_error.take() { + log::error!("{}: Control service failed: {:?}", self.0.sink.tag(), err); + let _ = self.0.sink.close(); + return Poll::Ready(Err(err)); + } // check readiness - let service_poll = self.service.poll_ready(cx).map_err(|err| { + let service_poll = self.0.service.poll_ready(cx).map_err(|err| { let err = Error::from(err); log::error!( "{}: Publish service readiness check failed: {:?}", - self.sink.tag(), + self.0.sink.tag(), err ); - let _ = self.sink.close_with_error(err); + let _ = self.0.sink.close_with_error(err); AmqpDispatcherError::Service })?; - let ctl_service_poll = self.ctl_service.poll_ready(cx).map_err(|err| { + let ctl_service_poll = self.0.ctl_service.poll_ready(cx).map_err(|err| { let err = Error::from(err); log::error!( "{}: Control service readiness check failed: {:?}", - self.sink.tag(), + self.0.sink.tag(), err ); - let _ = self.sink.close_with_error(err); + let _ = self.0.sink.close_with_error(err); AmqpDispatcherError::Service })?; - // enqueue pending control frames - if ctl_service_poll.is_ready() && !self.ctl_queue.pending.borrow().is_empty() { - self.ctl_queue - .pending - .borrow_mut() - .drain(..) - .for_each(|frame| { - self.call_control_service(frame); - }); - control_fut_pending = true; - } - - if control_fut_pending || service_poll.is_pending() || ctl_service_poll.is_pending() { + if service_poll.is_pending() || ctl_service_poll.is_pending() { Poll::Pending } else { Poll::Ready(Ok(())) @@ -254,18 +110,39 @@ where .await } + async fn not_ready(&self) { + select( + select( + poll_fn(|cx| self.0.service.poll_not_ready(cx)), + poll_fn(|cx| self.0.ctl_service.poll_not_ready(cx)), + ), + poll_fn(|cx| { + self.0.ctl_error_waker.register(cx.waker()); + if let Some(err) = self.0.ctl_error.take() { + self.0.ctl_error.set(Some(err)); + Poll::Ready(()) + } else { + Poll::Pending + } + }), + ) + .await; + } + async fn shutdown(&self) { - self.sink + self.0 + .sink .0 .get_mut() .set_error(AmqpProtocolError::Disconnected); let _ = self + .0 .ctl_service .call(ControlFrame::new_kind(ControlFrameKind::Closed)) .await; - self.service.shutdown().await; - self.ctl_service.shutdown().await; + self.0.service.shutdown().await; + self.0.ctl_service.shutdown().await; } async fn call( @@ -276,9 +153,10 @@ where match request { DispatchItem::Item(frame) => { #[cfg(feature = "frame-trace")] - log::trace!("{}: incoming: {:#?}", self.sink.tag(), frame); + log::trace!("{}: incoming: {:#?}", self.0.sink.tag(), frame); let action = match self + .0 .sink .handle_frame(frame) .map_err(AmqpDispatcherError::Protocol) @@ -289,9 +167,10 @@ where match action { types::Action::Transfer(link) => { - if self.sink.is_opened() { + if self.0.sink.is_opened() { let lnk = link.clone(); - if let Err(e) = self.service.call(types::Message::Transfer(link)).await + if let Err(e) = + self.0.service.call(types::Message::Transfer(link)).await { let e = Error::from(e); log::trace!("Service error {:?}", e); @@ -326,7 +205,7 @@ where } types::Action::DetachReceiver(link, frm) => { let lnk = link.clone(); - let fut = self.service.call(types::Message::Detached(lnk)); + let fut = self.0.service.call(types::Message::Detached(lnk)); let _ = spawn(async move { let _ = fut.await; }); @@ -344,7 +223,7 @@ where }) .collect(); - let fut = self.service.call(types::Message::DetachedAll(receivers)); + let fut = self.0.service.call(types::Message::DetachedAll(receivers)); let _ = spawn(async move { let _ = fut.await; }); @@ -391,6 +270,185 @@ where } } +impl Dispatcher +where + Sr: Service + 'static, + Ctl: Service + 'static, + Error: From + From, +{ + fn start_idle_timer(&self) { + if self.0.idle_timeout.non_zero() { + let slf = self.0.clone(); + ntex::rt::spawn(async move { + poll_fn(|cx| slf.idle_sleep.poll_elapsed(cx)).await; + if slf.stopped.get() || !slf.sink.is_opened() { + return; + } + log::trace!( + "{}: Send keep-alive ping, timeout: {:?} secs", + slf.sink.tag(), + slf.idle_timeout + ); + slf.sink.post_frame(AmqpFrame::new(0, Frame::Empty)); + slf.idle_sleep.reset(slf.idle_timeout); + }); + } + } + + fn start_control_queue(&self) { + let slf = self.0.clone(); + let queue = self.0.sink.get_control_queue().clone(); + let on_close = self.0.sink.get_ref().on_close(); + ntex::rt::spawn(async move { + let mut futs: Vec<(ControlFrame, PipelineCall)> = Vec::new(); + poll_fn(|cx| { + queue.waker.register(cx.waker()); + + // enqueue pending control frames + queue.pending.borrow_mut().drain(..).for_each(|frame| { + let fut = slf.ctl_service.call(frame.clone()); + futs.push((frame, fut)); + }); + + // process control frame + let mut idx = 0; + while futs.len() > idx { + let item = &mut futs[idx]; + let res = match Pin::new(&mut item.1).poll(cx) { + Poll::Pending => { + idx += 1; + continue; + } + Poll::Ready(res) => res, + }; + let (frame, _) = futs.swap_remove(idx); + let result = match res { + Ok(_) => slf.handle_control_frame(&frame, None), + Err(e) => slf.handle_control_frame(&frame, Some(e.into())), + }; + + if let Err(err) = result { + slf.ctl_error.set(Some(err)); + slf.ctl_error_waker.wake(); + return Poll::Ready(()); + } + } + + if !slf.sink.is_opened() { + let _ = on_close.poll_ready(cx); + } + + if !futs.is_empty() || !slf.stopped.get() || slf.sink.is_opened() { + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await; + }); + } +} + +impl Drop for Dispatcher +where + Sr: Service, + Ctl: Service, +{ + fn drop(&mut self) { + self.0.stopped.set(true); + self.0.idle_sleep.elapse(); + } +} + +impl DispatcherInner +where + Sr: Service + 'static, + Ctl: Service + 'static, + Error: From + From, +{ + fn handle_control_frame( + &self, + frame: &ControlFrame, + err: Option, + ) -> Result<(), AmqpDispatcherError> { + if let Some(err) = err { + match &frame.0.get_mut().kind { + ControlFrameKind::AttachReceiver(_, ref link) => { + let _ = link.close_with_error(err); + } + ControlFrameKind::AttachSender(ref frm, ref link) => { + frame + .session_cell() + .get_mut() + .detach_unconfirmed_sender_link(frm, link.inner.clone(), Some(err)); + } + ControlFrameKind::Flow(_, ref link) => { + let _ = link.close_with_error(err); + } + ControlFrameKind::LocalDetachSender(..) => {} + ControlFrameKind::LocalDetachReceiver(..) => {} + ControlFrameKind::RemoteDetachSender(_, ref link) => { + let _ = link.close_with_error(err); + } + ControlFrameKind::RemoteDetachReceiver(_, ref link) => { + let _ = link.close_with_error(err); + } + ControlFrameKind::ProtocolError(ref err) => { + self.sink.set_error(err.clone()); + return Err(err.clone().into()); + } + ControlFrameKind::Closed | ControlFrameKind::Disconnected(_) => { + self.sink.set_error(AmqpProtocolError::Disconnected); + } + ControlFrameKind::LocalSessionEnded(_) + | ControlFrameKind::RemoteSessionEnded(_) => (), + } + } else { + match frame.0.get_mut().kind { + ControlFrameKind::AttachReceiver(ref frm, ref link) => { + let link = link.clone(); + let frm = frm.clone(); + let fut = self + .service + .call(types::Message::Attached(frm.clone(), link.clone())); + let _ = ntex::rt::spawn(async move { + let result = fut.await; + if let Err(err) = result { + let _ = link.close_with_error(Error::from(err)).await; + } else { + link.confirm_receiver_link(&frm); + link.set_link_credit(50); + } + }); + } + ControlFrameKind::AttachSender(ref frm, ref link) => { + frame + .session_cell() + .get_mut() + .attach_remote_sender_link(frm, link.inner.clone()); + } + ControlFrameKind::Flow(ref frm, ref link) => { + frame.session_cell().get_mut().handle_flow(frm, Some(link)); + } + ControlFrameKind::ProtocolError(ref err) => { + self.sink.set_error(err.clone()); + return Err(err.clone().into()); + } + ControlFrameKind::Closed | ControlFrameKind::Disconnected(_) => { + self.sink.set_error(AmqpProtocolError::Disconnected); + } + ControlFrameKind::LocalDetachSender(..) + | ControlFrameKind::LocalDetachReceiver(..) + | ControlFrameKind::LocalSessionEnded(_) + | ControlFrameKind::RemoteDetachSender(..) + | ControlFrameKind::RemoteDetachReceiver(..) + | ControlFrameKind::RemoteSessionEnded(_) => (), + } + } + Ok(()) + } +} + pin_project_lite::pin_project! { pub struct ServiceResult<'f, F, E> where F: 'f diff --git a/src/types.rs b/src/types.rs index 312f551..2e0bbb3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -19,6 +19,7 @@ pub enum Message { Transfer(ReceiverLink), } +#[derive(Debug)] pub(crate) enum Action { None, AttachSender(SenderLink, Attach),