diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 6587585cc3..9a703a6a7f 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -27,7 +27,7 @@ use std::net::SocketAddr; use jsonrpsee::core::client::ClientT; -use jsonrpsee::server::{RpcServiceBuilder, Server}; +use jsonrpsee::server::{PingConfig, RpcServiceBuilder, Server}; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{rpc_params, RpcModule}; use tracing_subscriber::util::SubscriberInitExt; @@ -51,7 +51,11 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { let rpc_middleware = RpcServiceBuilder::new().rpc_logger(1024); - let server = Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?; + let server = Server::builder() + .enable_ws_ping(PingConfig::new()) + .set_rpc_middleware(rpc_middleware) + .build("127.0.0.1:0") + .await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _, _| "lo")?; let addr = server.local_addr()?; diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8bccb14e25..5efc2c4433 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,14 +1,15 @@ use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use crate::future::{IntervalStream, SessionClose}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; +use crate::utils::PendingPings; use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET}; -use futures_util::future::{self, Either}; +use futures_util::future::{self, Either, Fuse}; use futures_util::io::{BufReader, BufWriter}; -use futures_util::{Future, StreamExt, TryStreamExt}; +use futures_util::{Future, FutureExt, StreamExt, TryStreamExt}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods}; @@ -18,7 +19,7 @@ use soketto::connection::Error as SokettoError; use soketto::data::ByteSlice125; use tokio::sync::{mpsc, oneshot}; -use tokio::time::{interval, interval_at}; +use tokio::time::interval; use tokio_stream::wrappers::ReceiverStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -38,7 +39,6 @@ pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Resul } pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> { - tracing::debug!(target: LOG_TARGET, "Send ping"); // Submit empty slice as "optional" parameter. let slice: &[u8] = &[]; // Byte slice fails if the provided slice is larger than 125 bytes. @@ -76,17 +76,28 @@ where mut on_session_close, extensions, } = params; + + let conn_id = conn.conn_id; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; - let (conn_tx, conn_rx) = oneshot::channel(); + // Spawn ping/pong task if ping config is provided. + let ping_config = if let Some(ping_config) = ping_config { + let (ping_tx, ping_rx) = mpsc::channel::(4); + tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, conn_id)); + Some((ping_config, ping_tx)) + } else { + None + }; + + let ping_tx = ping_config.as_ref().map(|(_, tx)| tx.clone()); + // Spawn another task that sends out the responses on the Websocket. let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx)); let stopped = conn.stop_handle.clone().shutdown(); let rpc_service = Arc::new(rpc_service); - let mut missed_pings = 0; tokio::pin!(stopped); @@ -106,8 +117,9 @@ where tokio::pin!(ws_stream); let result = loop { - let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await { + let data = match try_recv(&mut ws_stream, stopped, ping_tx.as_ref()).await { Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed), + Receive::KeepAliveExpired => break Ok(Shutdown::KeepAliveExpired), Receive::Stopped => break Ok(Shutdown::Stopped), Receive::Ok(data, stop) => { stopped = stop; @@ -134,7 +146,6 @@ where continue; } err => { - tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id); break Err(err); } }; @@ -186,6 +197,8 @@ where }); }; + tracing::debug!(target: LOG_TARGET, "Connection closed for conn_id={conn_id}, reason={:?}", result); + // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee // proper drop behaviour. @@ -203,23 +216,23 @@ where async fn send_task( rx: mpsc::Receiver, mut ws_sender: Sender, - ping_config: Option, + ping_config: Option<(PingConfig, mpsc::Sender)>, stop: oneshot::Receiver<()>, ) { - let ping_interval = match ping_config { - None => IntervalStream::pending(), - // NOTE: we are emitted a tick here immediately to sync - // with how the receive task work because it starts measuring the pong - // when it starts up. - Some(p) => IntervalStream::new(interval(p.ping_interval)), + // Ping task is only spawned if ping config is provided. + let ping = match ping_config { + None => Either::Left(IntervalStream::pending().map(|_| None)), + Some((p, ping_tx)) => { + Either::Right(IntervalStream::new(interval(p.ping_interval)).map(move |_| Some(ping_tx.clone()))) + } }; let rx = ReceiverStream::new(rx); - tokio::pin!(ping_interval, rx, stop); + tokio::pin!(ping, rx, stop); // Received messages from the WebSocket. let mut rx_item = rx.next(); - let next_ping = ping_interval.next(); + let next_ping = ping.next(); let mut futs = future::select(next_ping, stop); loop { @@ -244,7 +257,7 @@ async fn send_task( } // Handle timer intervals. - Either::Right((Either::Left((_instant, _stopped)), next_rx)) => { + Either::Right((Either::Left((Some(ping_tx), _stopped)), next_rx)) => { stop = _stopped; if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err); @@ -252,8 +265,21 @@ async fn send_task( } rx_item = next_rx; - futs = future::select(ping_interval.next(), stop); + + let ping_tx = ping_tx.expect("ping tx is only `None` if ping_config is `None` checked above; qed"); + tokio::spawn(async move { + ping_tx.send(KeepAlive::Ping(Instant::now())).await.ok(); + }); + + futs = future::select(ping.next(), stop); } + + // The interval stream has been closed. + // This should be unreachable because the interval stream never ends. + Either::Right((Either::Left((None, _stopped)), _)) => { + break; + } + Either::Right((Either::Right((_stopped, _)), _)) => { // server has stopped break; @@ -268,68 +294,101 @@ async fn send_task( enum Receive { ConnectionClosed, + KeepAliveExpired, Stopped, Err(SokettoError, S), Ok(Vec, S), } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv( - ws_stream: &mut T, - mut stopped: S, - ping_config: Option, - missed_pings: &mut usize, -) -> Receive +async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: Option<&mpsc::Sender>) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, { - let mut last_active = Instant::now(); - let inactivity_check = match ping_config { - Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)), - None => IntervalStream::pending(), + let mut futs = future::select(ws_stream.next(), stopped); + let closed = match ping_tx { + Some(ping_tx) => ping_tx.closed().fuse(), + None => Fuse::terminated(), }; - tokio::pin!(inactivity_check); - - let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next()); + tokio::pin!(closed); loop { - match futures_util::future::select(futs, stopped).await { + match future::select(futs, closed).await { // The connection is closed. Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed, // The message has been received, we are done - Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s), - // Got a pong response, update our "last seen" timestamp. - Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => { - last_active = Instant::now(); - stopped = s; - futs = futures_util::future::select(ws_stream.next(), inactive); + Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), s)), _)) => { + if let Some(ping_tx) = ping_tx { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Data(Instant::now())).await; + }); + } + + break Receive::Ok(d, s); + } + // Got a pong response send status to the ping_pong_task. + Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), c)) => { + if let Some(ping_tx) = ping_tx { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; + }); + } + futs = futures_util::future::select(ws_stream.next(), s); + closed = c; } // Received an error, terminate the connection. - Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s), - // Max inactivity timeout fired, check if the connection has been idle too long. - Either::Left((Either::Right((_instant, rcv)), s)) => { - if let Some(p) = ping_config { - if last_active.elapsed() > p.inactive_limit { - *missed_pings += 1; - - if *missed_pings >= p.max_failures { - tracing::debug!( - target: LOG_TARGET, - "WS ping/pong inactivity limit `{}` exceeded; closing connection", - p.max_failures, - ); - break Receive::ConnectionClosed; + Either::Left((Either::Left((Some(Err(e)), s)), _)) => break Receive::Err(e, s), + + // Server has been stopped or closed by inactive peer. + Either::Left((Either::Right((_, _)), _)) => break Receive::Stopped, + // Ping task has been stopped. + Either::Right((_, _)) => break Receive::KeepAliveExpired, + } + } +} + +#[derive(Debug, Copy, Clone)] +pub(crate) enum KeepAlive { + Ping(Instant), + Data(Instant), + Pong(Instant), +} + +async fn ping_pong_task( + mut rx: mpsc::Receiver, + max_inactivity_dur: Duration, + max_missed_pings: usize, + conn_id: u32, +) { + let mut polling_interval = IntervalStream::new(interval(max_inactivity_dur)); + let mut pending_pings = PendingPings::new(max_missed_pings, max_inactivity_dur, conn_id); + + loop { + tokio::select! { + // If the ping is never answered, we use this timer as a fallback. + _ = polling_interval.next() => { + if !pending_pings.check_alive() { + break; + } + } + // Data on the connection. + msg = rx.recv() => { + match msg { + Some(KeepAlive::Ping(start)) => { + pending_pings.push(start); + } + Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => { + if !pending_pings.alive_response(end) { + break; } } + None => break, } - - stopped = s; - futs = futures_util::future::select(rcv, inactivity_check.next()); } - // Server has been stopped. - Either::Right(_) => break Receive::Stopped, } } } @@ -338,6 +397,7 @@ where pub(crate) enum Shutdown { Stopped, ConnectionClosed, + KeepAliveExpired, } /// Enforce a graceful shutdown. diff --git a/server/src/utils.rs b/server/src/utils.rs index d510a84661..dc3a86da0e 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -24,11 +24,13 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; -use crate::{HttpBody, HttpRequest}; +use crate::{HttpBody, HttpRequest, LOG_TARGET}; use futures_util::future::{self, Either}; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -83,6 +85,90 @@ where } } +#[derive(Debug, Clone)] +pub(crate) struct PendingPings { + list: VecDeque, + max_missed_pings: usize, + missed_pings: usize, + max_inactivity_dur: Duration, + conn_id: u32, +} + +impl PendingPings { + pub(crate) fn new(max_missed_pings: usize, max_inactivity_dur: Duration, conn_id: u32) -> Self { + Self { list: VecDeque::new(), max_missed_pings, max_inactivity_dur, missed_pings: 0, conn_id } + } + + fn log_ping_expired(elapsed: Duration, conn_id: u32, max_inactivity_dur: Duration) { + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactivity_dur.as_millis()); + } + + fn log_connection_closed(missed_pings: usize, conn_id: u32) { + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); + } + + pub(crate) fn push(&mut self, instant: Instant) { + self.list.push_back(instant); + } + + /// Check if there are any pending pings that have expired + /// + /// It's different from [`PendingPing::alive_response`] because + /// this shouldn't be used when data is received. + /// + /// It's just way to ensure that pings are checked despite no message is received on the + /// connection. + /// + /// Returns `true` if the connection is still alive, `false` otherwise. + pub(crate) fn check_alive(&mut self) -> bool { + let mut list = VecDeque::new(); + + for ping_start in self.list.drain(..) { + if ping_start.elapsed() >= self.max_inactivity_dur { + self.missed_pings += 1; + Self::log_ping_expired(ping_start.elapsed(), self.conn_id, self.max_inactivity_dur); + } else { + list.push_back(ping_start); + } + + if self.missed_pings >= self.max_missed_pings { + Self::log_connection_closed(self.missed_pings, self.conn_id); + return false; + } + } + + self.list = list; + true + } + + /// Register a alive response. + /// + /// Returns `true` if the pong was answered in time, `false` otherwise. + pub(crate) fn alive_response(&mut self, end: Instant) -> bool { + for ping_start in self.list.drain(..) { + // Calculate the round-trip time (RTT) of the ping/pong. + // We adjust for the time when the pong was received. + let elapsed = ping_start.elapsed().saturating_sub(end.elapsed()); + + tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={}", elapsed.as_millis(), self.conn_id); + + if elapsed >= self.max_inactivity_dur { + self.missed_pings += 1; + Self::log_ping_expired(ping_start.elapsed(), self.conn_id, self.max_inactivity_dur); + } else { + self.missed_pings = 0; + } + + if self.missed_pings >= self.max_missed_pings { + Self::log_connection_closed(self.missed_pings, self.conn_id); + return false; + } + } + + true + } +} + /// Serve a service over a TCP connection without graceful shutdown. /// This means that pending requests will be dropped when the server is stopped. /// @@ -163,3 +249,47 @@ pub(crate) mod deserialize { Ok(req) } } + +#[cfg(test)] +mod tests { + use super::PendingPings; + use std::time::{Duration, Instant}; + + #[test] + fn pending_ping_works() { + let mut pending_pings = PendingPings::new(1, std::time::Duration::from_secs(1), 0); + + pending_pings.push(Instant::now()); + assert!(pending_pings.alive_response(std::time::Instant::now())); + assert!(pending_pings.list.is_empty()); + assert_eq!(pending_pings.missed_pings, 0); + } + + #[test] + fn inactive_too_long() { + let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); + + pending_pings.push(Instant::now()); + pending_pings.push(Instant::now()); + + std::thread::sleep(Duration::from_millis(200)); + + assert!(!pending_pings.check_alive()); + assert_eq!(pending_pings.missed_pings, 2); + } + + #[test] + fn active_reset_counter() { + let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); + pending_pings.push(std::time::Instant::now()); + + std::thread::sleep(Duration::from_millis(200)); + + assert!(pending_pings.check_alive()); + assert_eq!(pending_pings.missed_pings, 1); + + pending_pings.push(std::time::Instant::now()); + assert!(pending_pings.alive_response(Instant::now())); + assert_eq!(pending_pings.missed_pings, 0); + } +}