From 5217c9c59525b0a375f31880c58d5ed24339ea66 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Thu, 1 Aug 2024 13:41:13 +0200 Subject: [PATCH 01/16] server: syncronize ws ping/pong messages --- server/src/transport/ws.rs | 144 +++++++++++++++++++++++++------------ 1 file changed, 98 insertions(+), 46 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8bccb14e25..64884b4dfa 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,5 +1,6 @@ +use std::collections::VecDeque; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use crate::future::{IntervalStream, SessionClose}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; @@ -18,7 +19,7 @@ use soketto::connection::Error as SokettoError; use soketto::data::ByteSlice125; use tokio::sync::{mpsc, oneshot}; -use tokio::time::{interval, interval_at}; +use tokio::time::interval; use tokio_stream::wrappers::ReceiverStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -80,13 +81,17 @@ where server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); + let (ping_tx, ping_rx) = mpsc::channel::(4); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx)); + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone())); + + if let Some(ping_config) = ping_config { + tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures)); + } let stopped = conn.stop_handle.clone().shutdown(); let rpc_service = Arc::new(rpc_service); - let mut missed_pings = 0; tokio::pin!(stopped); @@ -106,7 +111,7 @@ where tokio::pin!(ws_stream); let result = loop { - let data = match try_recv(&mut ws_stream, stopped, ping_config, &mut missed_pings).await { + let data = match try_recv(&mut ws_stream, stopped, &ping_tx).await { Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed), Receive::Stopped => break Ok(Shutdown::Stopped), Receive::Ok(data, stop) => { @@ -205,6 +210,7 @@ async fn send_task( mut ws_sender: Sender, ping_config: Option, stop: oneshot::Receiver<()>, + ping_tx: mpsc::Sender, ) { let ping_interval = match ping_config { None => IntervalStream::pending(), @@ -252,6 +258,12 @@ async fn send_task( } rx_item = next_rx; + + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + ping_tx.send(KeepAlive::Ping(Instant::now())).await.ok(); + }); + futs = future::select(ping_interval.next(), stop); } Either::Right((Either::Right((_stopped, _)), _)) => { @@ -274,62 +286,102 @@ enum Receive { } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv( - ws_stream: &mut T, - mut stopped: S, - ping_config: Option, - missed_pings: &mut usize, -) -> Receive +async fn try_recv(ws_stream: &mut T, mut stopped: S, ping_tx: &mpsc::Sender) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, { - let mut last_active = Instant::now(); - let inactivity_check = match ping_config { - Some(p) => IntervalStream::new(interval_at(tokio::time::Instant::now() + p.ping_interval, p.ping_interval)), - None => IntervalStream::pending(), - }; - - tokio::pin!(inactivity_check); - - let mut futs = futures_util::future::select(ws_stream.next(), inactivity_check.next()); - loop { - match futures_util::future::select(futs, stopped).await { + match futures_util::future::select(ws_stream.next(), stopped).await { // The connection is closed. - Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed, + Either::Left((None, _)) => break Receive::ConnectionClosed, // The message has been received, we are done - Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), _)), s)) => break Receive::Ok(d, s), + Either::Left((Some(Ok(Incoming::Data(d))), s)) => { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; + }); + + break Receive::Ok(d, s); + } // Got a pong response, update our "last seen" timestamp. - Either::Left((Either::Left((Some(Ok(Incoming::Pong)), inactive)), s)) => { - last_active = Instant::now(); + Either::Left((Some(Ok(Incoming::Pong)), s)) => { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; + }); + stopped = s; - futs = futures_util::future::select(ws_stream.next(), inactive); } // Received an error, terminate the connection. - Either::Left((Either::Left((Some(Err(e)), _)), s)) => break Receive::Err(e, s), - // Max inactivity timeout fired, check if the connection has been idle too long. - Either::Left((Either::Right((_instant, rcv)), s)) => { - if let Some(p) = ping_config { - if last_active.elapsed() > p.inactive_limit { - *missed_pings += 1; - - if *missed_pings >= p.max_failures { - tracing::debug!( - target: LOG_TARGET, - "WS ping/pong inactivity limit `{}` exceeded; closing connection", - p.max_failures, - ); - break Receive::ConnectionClosed; - } + Either::Left((Some(Err(e)), s)) => break Receive::Err(e, s), + // Server has been stopped. + Either::Right(_) => break Receive::Stopped, + } + } +} + +#[derive(Debug, Copy, Clone)] +enum KeepAlive { + Ping(Instant), + Pong(Instant), +} + +async fn ping_pong_task(mut rx: mpsc::Receiver, inactive_limit: Duration, max_inactive: usize) { + let polling_interval = inactive_limit.mul_f64(1.2); + let mut pending_pings: VecDeque = VecDeque::new(); + let mut missed_pings = 0; + + loop { + tokio::select! { + // If the ping is never answered, we use this timer as a fallback. + _ = tokio::time::sleep(polling_interval) => { + let mut remove = false; + + if let Some(ping_start) = pending_pings.front() { + + if ping_start.elapsed() > inactive_limit { + missed_pings += 1; + remove = true; + } + + if missed_pings >= max_inactive { + tracing::debug!(target: LOG_TARGET, "Too many missed pings, closing connection"); + break; } } - stopped = s; - futs = futures_util::future::select(rcv, inactivity_check.next()); + if remove { + pending_pings.pop_front(); + } + } + msg = rx.recv() => { + match msg { + Some(KeepAlive::Ping(start)) => { + pending_pings.push_back(start); + } + Some(KeepAlive::Pong(end)) => { + if let Some(start) = pending_pings.pop_front() { + // Calculate the round-trip time (RTT) of the ping/pong. + // We adjust for the time to send it to this task. + let elapsed = start.elapsed() - end.elapsed(); + + if elapsed > inactive_limit { + missed_pings += 1; + } + + + tracing::debug!(target: LOG_TARGET, "ping/pong RTT: {:?}", elapsed); + + if missed_pings >= max_inactive { + tracing::debug!(target: LOG_TARGET, "Too many missed pings, closing connection"); + break; + } + } + } + None => break, + } } - // Server has been stopped. - Either::Right(_) => break Receive::Stopped, } } } From 0b86a2ab7d456993c85fed54e230ee8b162aab16 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 2 Aug 2024 09:04:58 +0200 Subject: [PATCH 02/16] fix nit: use interval stream --- server/src/transport/ws.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 64884b4dfa..c7aba3b374 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -327,26 +327,29 @@ enum KeepAlive { Pong(Instant), } -async fn ping_pong_task(mut rx: mpsc::Receiver, inactive_limit: Duration, max_inactive: usize) { - let polling_interval = inactive_limit.mul_f64(1.2); +async fn ping_pong_task(mut rx: mpsc::Receiver, max_inactive_limit: Duration, max_inactive: usize) { + let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); let mut pending_pings: VecDeque = VecDeque::new(); let mut missed_pings = 0; loop { tokio::select! { // If the ping is never answered, we use this timer as a fallback. - _ = tokio::time::sleep(polling_interval) => { + _ = polling_interval.next() => { let mut remove = false; if let Some(ping_start) = pending_pings.front() { + let elapsed = ping_start.elapsed(); - if ping_start.elapsed() > inactive_limit { + if elapsed > max_inactive_limit { missed_pings += 1; remove = true; } + tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired elapsed={:?}/max={:?}", elapsed, max_inactive_limit); + if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed pings, closing connection"); + tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs, closing connection"); break; } } @@ -366,15 +369,14 @@ async fn ping_pong_task(mut rx: mpsc::Receiver, inactive_limit: Durat // We adjust for the time to send it to this task. let elapsed = start.elapsed() - end.elapsed(); - if elapsed > inactive_limit { + if elapsed > max_inactive_limit { missed_pings += 1; } - tracing::debug!(target: LOG_TARGET, "ping/pong RTT: {:?}", elapsed); if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed pings, closing connection"); + tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs, closing connection"); break; } } From a792ba2329a4ea2ce5a208162e57599a0baf1ea5 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 2 Aug 2024 10:46:11 +0200 Subject: [PATCH 03/16] add noise for debugging --- examples/examples/ws.rs | 17 +++++++--- server/src/server.rs | 7 ++++ server/src/transport/ws.rs | 65 ++++++++++++++++++++++++++------------ 3 files changed, 65 insertions(+), 24 deletions(-) diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 6587585cc3..14cc32d062 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -25,9 +25,10 @@ // DEALINGS IN THE SOFTWARE. use std::net::SocketAddr; +use std::time::Duration; use jsonrpsee::core::client::ClientT; -use jsonrpsee::server::{RpcServiceBuilder, Server}; +use jsonrpsee::server::{PingConfig, RpcServiceBuilder, Server}; use jsonrpsee::ws_client::WsClientBuilder; use jsonrpsee::{rpc_params, RpcModule}; use tracing_subscriber::util::SubscriberInitExt; @@ -40,18 +41,26 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?; let addr = run_server().await?; - let url = format!("ws://{}", addr); + /*let url = format!("ws://{}", addr); let client = WsClientBuilder::default().build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; - tracing::info!("response: {:?}", response); + tracing::info!("response: {:?}", response);*/ + + tokio::time::sleep(Duration::from_secs(60 * 5)).await; Ok(()) } async fn run_server() -> anyhow::Result { let rpc_middleware = RpcServiceBuilder::new().rpc_logger(1024); - let server = Server::builder().set_rpc_middleware(rpc_middleware).build("127.0.0.1:0").await?; + let server = Server::builder() + .enable_ws_ping( + PingConfig::new().ping_interval(Duration::from_secs(10)).inactive_limit(Duration::from_secs(20)), + ) + .set_rpc_middleware(rpc_middleware) + .build("127.0.0.1:9944") + .await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _, _| "lo")?; let addr = server.local_addr()?; diff --git a/server/src/server.rs b/server/src/server.rs index dc9c036189..76ac2abb1c 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -486,6 +486,7 @@ impl TowerServiceBuilder, stop_handle: StopHandle, + remote_addr: SocketAddr, ) -> TowerService { let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -497,6 +498,7 @@ impl TowerServiceBuilder(conn.conn_id.into()); + request.extensions_mut().insert(self.inner.remote_addr); let is_upgrade_request = is_upgrade_request(&request); @@ -1190,6 +1195,7 @@ where stop_handle, drop_on_completion, methods, + remote_addr, .. } = params; @@ -1204,6 +1210,7 @@ where methods, stop_handle: stop_handle.clone(), conn_id, + remote_addr, conn_guard: conn_guard.clone(), }, rpc_middleware, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index c7aba3b374..20e9137e29 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,4 +1,5 @@ use std::collections::VecDeque; +use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -39,7 +40,6 @@ pub(crate) async fn send_message(sender: &mut Sender, response: String) -> Resul } pub(crate) async fn send_ping(sender: &mut Sender) -> Result<(), SokettoError> { - tracing::debug!(target: LOG_TARGET, "Send ping"); // Submit empty slice as "optional" parameter. let slice: &[u8] = &[]; // Byte slice fails if the provided slice is larger than 125 bytes. @@ -77,17 +77,19 @@ where mut on_session_close, extensions, } = params; + + let remote_addr = extensions.get::().cloned().unwrap_or_else(|| "unknown".parse().unwrap()); let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; - let (conn_tx, conn_rx) = oneshot::channel(); let (ping_tx, ping_rx) = mpsc::channel::(4); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone())); + let send_task_handle = + tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), remote_addr.clone())); if let Some(ping_config) = ping_config { - tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures)); + tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, remote_addr)); } let stopped = conn.stop_handle.clone().shutdown(); @@ -211,6 +213,7 @@ async fn send_task( ping_config: Option, stop: oneshot::Receiver<()>, ping_tx: mpsc::Sender, + remote_addr: SocketAddr, ) { let ping_interval = match ping_config { None => IntervalStream::pending(), @@ -235,11 +238,17 @@ async fn send_task( // Received message. Either::Left((Some(response), not_ready)) => { // If websocket message send fail then terminate the connection. + let now = Instant::now(); if let Err(err) = send_message(&mut ws_sender, response).await { tracing::debug!(target: LOG_TARGET, "WS send error: {}", err); break; } + if now.elapsed() > Duration::from_secs(30) { + tracing::warn!(target: LOG_TARGET, "Send message was slow {:?}, peer={:?}", now.elapsed(), remote_addr); + break; + } + rx_item = rx.next(); futs = not_ready; } @@ -252,11 +261,17 @@ async fn send_task( // Handle timer intervals. Either::Right((Either::Left((_instant, _stopped)), next_rx)) => { stop = _stopped; + let now = Instant::now(); if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err); break; } + if now.elapsed() > Duration::from_secs(30) { + tracing::warn!(target: LOG_TARGET, "Send ping was slow {:?}, peer={:?}", now.elapsed(), remote_addr); + break; + } + rx_item = next_rx; let ping_tx = ping_tx.clone(); @@ -286,17 +301,22 @@ enum Receive { } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv(ws_stream: &mut T, mut stopped: S, ping_tx: &mpsc::Sender) -> Receive +async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: &mpsc::Sender) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, { + let mut futs = future::select(ws_stream.next(), stopped); + loop { - match futures_util::future::select(ws_stream.next(), stopped).await { + let closed = ping_tx.closed(); + tokio::pin!(closed); + + match future::select(futs, closed).await { // The connection is closed. - Either::Left((None, _)) => break Receive::ConnectionClosed, + Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed, // The message has been received, we are done - Either::Left((Some(Ok(Incoming::Data(d))), s)) => { + Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), s)), _)) => { let ping_tx = ping_tx.clone(); tokio::spawn(async move { _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; @@ -304,19 +324,19 @@ where break Receive::Ok(d, s); } - // Got a pong response, update our "last seen" timestamp. - Either::Left((Some(Ok(Incoming::Pong)), s)) => { + // Got a pong response send status to the ping_pong_task. + Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), _)) => { let ping_tx = ping_tx.clone(); tokio::spawn(async move { _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; }); - stopped = s; + futs = futures_util::future::select(ws_stream.next(), s); } // Received an error, terminate the connection. - Either::Left((Some(Err(e)), s)) => break Receive::Err(e, s), - // Server has been stopped. - Either::Right(_) => break Receive::Stopped, + Either::Left((Either::Left((Some(Err(e)), s)), _)) => break Receive::Err(e, s), + // Server has been stopped or closed by inactive peer. + _ => break Receive::Stopped, } } } @@ -327,7 +347,12 @@ enum KeepAlive { Pong(Instant), } -async fn ping_pong_task(mut rx: mpsc::Receiver, max_inactive_limit: Duration, max_inactive: usize) { +async fn ping_pong_task( + mut rx: mpsc::Receiver, + max_inactive_limit: Duration, + max_inactive: usize, + remote_addr: SocketAddr, +) { let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); let mut pending_pings: VecDeque = VecDeque::new(); let mut missed_pings = 0; @@ -341,15 +366,15 @@ async fn ping_pong_task(mut rx: mpsc::Receiver, max_inactive_limit: D if let Some(ping_start) = pending_pings.front() { let elapsed = ping_start.elapsed(); - if elapsed > max_inactive_limit { + if elapsed >= max_inactive_limit { missed_pings += 1; remove = true; } - tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired elapsed={:?}/max={:?}", elapsed, max_inactive_limit); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs, closing connection"); + tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}, closing connection", remote_addr); break; } } @@ -369,14 +394,14 @@ async fn ping_pong_task(mut rx: mpsc::Receiver, max_inactive_limit: D // We adjust for the time to send it to this task. let elapsed = start.elapsed() - end.elapsed(); - if elapsed > max_inactive_limit { + if elapsed >= max_inactive_limit { missed_pings += 1; } tracing::debug!(target: LOG_TARGET, "ping/pong RTT: {:?}", elapsed); if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs, closing connection"); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); break; } } From ec922c304a9299486a718ebcfe14a0088c07d46c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 2 Aug 2024 11:57:54 +0200 Subject: [PATCH 04/16] fix logs --- server/src/transport/ws.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 20e9137e29..b6fde874b8 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -369,12 +369,11 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; remove = true; + tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); } - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); - if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}, closing connection", remote_addr); + tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}; closing connection", remote_addr); break; } } @@ -396,12 +395,13 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; + tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); } tracing::debug!(target: LOG_TARGET, "ping/pong RTT: {:?}", elapsed); if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); + tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}; closing connection", remote_addr); break; } } From addd5d7ed08fccb3a7d1ce73f5a22a6c687deebc Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 2 Aug 2024 12:21:19 +0200 Subject: [PATCH 05/16] more debug logs --- examples/examples/ws.rs | 12 +++++------- server/src/transport/ws.rs | 14 +++++++++++--- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 14cc32d062..612f853d6e 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -41,13 +41,13 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?; let addr = run_server().await?; - /*let url = format!("ws://{}", addr); + let url = format!("ws://{}", addr); let client = WsClientBuilder::default().build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; - tracing::info!("response: {:?}", response);*/ + tracing::info!("response: {:?}", response); - tokio::time::sleep(Duration::from_secs(60 * 5)).await; + tokio::time::sleep(Duration::from_secs(60 * 4)).await; Ok(()) } @@ -55,11 +55,9 @@ async fn main() -> anyhow::Result<()> { async fn run_server() -> anyhow::Result { let rpc_middleware = RpcServiceBuilder::new().rpc_logger(1024); let server = Server::builder() - .enable_ws_ping( - PingConfig::new().ping_interval(Duration::from_secs(10)).inactive_limit(Duration::from_secs(20)), - ) + .enable_ws_ping(PingConfig::new()) .set_rpc_middleware(rpc_middleware) - .build("127.0.0.1:9944") + .build("127.0.0.1:0") .await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _, _| "lo")?; diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index b6fde874b8..8eba7b06f8 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -113,7 +113,7 @@ where tokio::pin!(ws_stream); let result = loop { - let data = match try_recv(&mut ws_stream, stopped, &ping_tx).await { + let data = match try_recv(&mut ws_stream, stopped, &ping_tx, remote_addr).await { Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed), Receive::Stopped => break Ok(Shutdown::Stopped), Receive::Ok(data, stop) => { @@ -262,6 +262,7 @@ async fn send_task( Either::Right((Either::Left((_instant, _stopped)), next_rx)) => { stop = _stopped; let now = Instant::now(); + tracing::debug!(target: LOG_TARGET, "Send ping to peer={:?}", remote_addr); if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err); break; @@ -301,7 +302,12 @@ enum Receive { } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: &mpsc::Sender) -> Receive +async fn try_recv( + ws_stream: &mut T, + stopped: S, + ping_tx: &mpsc::Sender, + remote_addr: SocketAddr, +) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, @@ -326,6 +332,8 @@ where } // Got a pong response send status to the ping_pong_task. Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), _)) => { + tracing::debug!(target: LOG_TARGET, "Received pong from peer={:?}", remote_addr); + let ping_tx = ping_tx.clone(); tokio::spawn(async move { _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; @@ -398,7 +406,7 @@ async fn ping_pong_task( tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); } - tracing::debug!(target: LOG_TARGET, "ping/pong RTT: {:?}", elapsed); + tracing::debug!(target: LOG_TARGET, "ping_pong_rtt={:?}, peer={:?}", elapsed, remote_addr); if missed_pings >= max_inactive { tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}; closing connection", remote_addr); From 780bc7aa752c8583e1e22b9aaacfc843b855b5f1 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Mon, 5 Aug 2024 15:34:22 +0200 Subject: [PATCH 06/16] more debug logs --- server/src/transport/ws.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8eba7b06f8..8159876d3f 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -141,7 +141,6 @@ where continue; } err => { - tracing::debug!(target: LOG_TARGET, "WS error: {}; terminate connection: {}", err, conn.conn_id); break Err(err); } }; @@ -193,6 +192,8 @@ where }); }; + tracing::debug!(target: LOG_TARGET, "Connection closed for peer={}, reason={:?}", remote_addr, result); + // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee // proper drop behaviour. From 3fe5554ca37407af75480a4ad3759c5fd78dbb8c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Mon, 5 Aug 2024 16:10:03 +0200 Subject: [PATCH 07/16] more updates --- examples/examples/ws.rs | 1 + server/src/transport/ws.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 612f853d6e..38852eb029 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -46,6 +46,7 @@ async fn main() -> anyhow::Result<()> { let client = WsClientBuilder::default().build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; tracing::info!("response: {:?}", response); + drop(client); tokio::time::sleep(Duration::from_secs(60 * 4)).await; diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 8159876d3f..fd14610652 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -400,7 +400,7 @@ async fn ping_pong_task( if let Some(start) = pending_pings.pop_front() { // Calculate the round-trip time (RTT) of the ping/pong. // We adjust for the time to send it to this task. - let elapsed = start.elapsed() - end.elapsed(); + let elapsed = start.elapsed().saturating_sub(end.elapsed()); if elapsed >= max_inactive_limit { missed_pings += 1; From 4cec9358a432932823d8d0001fc29b66f549b65e Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 7 Aug 2024 16:19:50 +0200 Subject: [PATCH 08/16] cleanup --- examples/examples/ws.rs | 3 -- server/src/server.rs | 60 ++++++++++++++++++++++++++++++++------ server/src/transport/ws.rs | 52 ++++++++++++++++----------------- 3 files changed, 77 insertions(+), 38 deletions(-) diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 38852eb029..16bbd9c860 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -46,9 +46,6 @@ async fn main() -> anyhow::Result<()> { let client = WsClientBuilder::default().build(&url).await?; let response: String = client.request("say_hello", rpc_params![]).await?; tracing::info!("response: {:?}", response); - drop(client); - - tokio::time::sleep(Duration::from_secs(60 * 4)).await; Ok(()) } diff --git a/server/src/server.rs b/server/src/server.rs index 76ac2abb1c..c485add6a3 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -26,7 +26,7 @@ use std::error::Error as StdError; use std::future::Future; -use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener}; use std::pin::Pin; use std::sync::atomic::AtomicU32; use std::sync::Arc; @@ -240,6 +240,8 @@ pub struct TowerServiceBuilder { pub(crate) conn_id: Arc, /// Connection guard. pub(crate) conn_guard: ConnectionGuard, + /// Remote address. + pub(crate) remote_addr: Option, } /// Configuration for batch request handling. @@ -461,6 +463,7 @@ pub struct Builder { server_cfg: ServerConfig, rpc_middleware: RpcServiceBuilder, http_middleware: tower::ServiceBuilder, + remote_addr: Option, } impl Default for Builder { @@ -469,6 +472,7 @@ impl Default for Builder { server_cfg: ServerConfig::default(), rpc_middleware: RpcServiceBuilder::new(), http_middleware: tower::ServiceBuilder::new(), + remote_addr: None, } } } @@ -478,6 +482,12 @@ impl Builder { pub fn new() -> Self { Self::default() } + + /// Set the address of the remote peer. + pub fn set_remote_addr(mut self, remote_addr: SocketAddr) -> Self { + self.remote_addr = Some(remote_addr); + self + } } impl TowerServiceBuilder { @@ -486,7 +496,6 @@ impl TowerServiceBuilder, stop_handle: StopHandle, - remote_addr: SocketAddr, ) -> TowerService { let conn_id = self.conn_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -498,7 +507,7 @@ impl TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder { + TowerServiceBuilder { + server_cfg: self.server_cfg, + rpc_middleware: self.rpc_middleware, + http_middleware: self.http_middleware, + conn_id: self.conn_id, + conn_guard: self.conn_guard, + remote_addr: Some(remote_addr), } } } @@ -639,7 +662,12 @@ impl Builder { /// let builder = ServerBuilder::default().set_rpc_middleware(m); /// ``` pub fn set_rpc_middleware(self, rpc_middleware: RpcServiceBuilder) -> Builder { - Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware } + Builder { + server_cfg: self.server_cfg, + rpc_middleware, + http_middleware: self.http_middleware, + remote_addr: self.remote_addr, + } } /// Configure a custom [`tokio::runtime::Handle`] to run the server on. @@ -725,7 +753,12 @@ impl Builder { /// } /// ``` pub fn set_http_middleware(self, http_middleware: tower::ServiceBuilder) -> Builder { - Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware } + Builder { + server_cfg: self.server_cfg, + http_middleware, + rpc_middleware: self.rpc_middleware, + remote_addr: self.remote_addr, + } } /// Configure `TCP_NODELAY` on the socket to the supplied value `nodelay`. @@ -861,6 +894,7 @@ impl Builder { http_middleware: self.http_middleware, conn_id: Arc::new(AtomicU32::new(0)), conn_guard: ConnectionGuard::new(max_conns), + remote_addr: self.remote_addr, } } @@ -938,12 +972,12 @@ struct ServiceData { stop_handle: StopHandle, /// Connection ID conn_id: u32, - /// Remote addr - remote_addr: SocketAddr, /// Connection guard. conn_guard: ConnectionGuard, /// ServerConfig server_cfg: ServerConfig, + /// Remote address. + remote_addr: Option, } /// jsonrpsee tower service @@ -1052,7 +1086,15 @@ where tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns); request.extensions_mut().insert::(conn.conn_id.into()); - request.extensions_mut().insert(self.inner.remote_addr); + + if let Some(remote_addr) = self.inner.remote_addr { + // Only insert the remote address if it's not already set. + // We expect servers deployed behind a reverse proxy to set the remote address + // themselves otherwise the remote address will be the address of the reverse proxy. + if request.extensions().get::().is_none() { + request.extensions_mut().insert(remote_addr.ip()); + } + } let is_upgrade_request = is_upgrade_request(&request); @@ -1210,8 +1252,8 @@ where methods, stop_handle: stop_handle.clone(), conn_id, - remote_addr, conn_guard: conn_guard.clone(), + remote_addr: Some(remote_addr), }, rpc_middleware, on_session_close: None, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index fd14610652..0df349fd16 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,5 +1,5 @@ use std::collections::VecDeque; -use std::net::SocketAddr; +use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -78,15 +78,16 @@ where extensions, } = params; - let remote_addr = extensions.get::().cloned().unwrap_or_else(|| "unknown".parse().unwrap()); + // NOTE: jsonrpsee only inject the `remote_addr` if it not set because for servers that are behind a reverse proxy, + // needs read HTTP headers by the reverse proxy. + let remote_addr = extensions.get::().copied(); let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); let (ping_tx, ping_rx) = mpsc::channel::(4); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = - tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), remote_addr.clone())); + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), remote_addr)); if let Some(ping_config) = ping_config { tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, remote_addr)); @@ -113,8 +114,9 @@ where tokio::pin!(ws_stream); let result = loop { - let data = match try_recv(&mut ws_stream, stopped, &ping_tx, remote_addr).await { + let data = match try_recv(&mut ws_stream, stopped, &ping_tx).await { Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed), + Receive::KeepAliveExpired => break Ok(Shutdown::KeepAliveExpired), Receive::Stopped => break Ok(Shutdown::Stopped), Receive::Ok(data, stop) => { stopped = stop; @@ -192,7 +194,7 @@ where }); }; - tracing::debug!(target: LOG_TARGET, "Connection closed for peer={}, reason={:?}", remote_addr, result); + tracing::debug!(target: LOG_TARGET, "Connection closed for peer={:?}, reason={:?}", remote_addr, result); // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee @@ -214,7 +216,7 @@ async fn send_task( ping_config: Option, stop: oneshot::Receiver<()>, ping_tx: mpsc::Sender, - remote_addr: SocketAddr, + remote_addr: Option, ) { let ping_interval = match ping_config { None => IntervalStream::pending(), @@ -263,14 +265,13 @@ async fn send_task( Either::Right((Either::Left((_instant, _stopped)), next_rx)) => { stop = _stopped; let now = Instant::now(); - tracing::debug!(target: LOG_TARGET, "Send ping to peer={:?}", remote_addr); if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err); break; } if now.elapsed() > Duration::from_secs(30) { - tracing::warn!(target: LOG_TARGET, "Send ping was slow {:?}, peer={:?}", now.elapsed(), remote_addr); + tracing::warn!(target: LOG_TARGET, "Send ping was slow {}s, peer={:?}", now.elapsed().as_secs(), remote_addr); break; } @@ -297,18 +298,14 @@ async fn send_task( enum Receive { ConnectionClosed, + KeepAliveExpired, Stopped, Err(SokettoError, S), Ok(Vec, S), } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv( - ws_stream: &mut T, - stopped: S, - ping_tx: &mpsc::Sender, - remote_addr: SocketAddr, -) -> Receive +async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: &mpsc::Sender) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, @@ -326,15 +323,13 @@ where Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), s)), _)) => { let ping_tx = ping_tx.clone(); tokio::spawn(async move { - _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; + _ = ping_tx.send(KeepAlive::Data(Instant::now())).await; }); break Receive::Ok(d, s); } // Got a pong response send status to the ping_pong_task. Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), _)) => { - tracing::debug!(target: LOG_TARGET, "Received pong from peer={:?}", remote_addr); - let ping_tx = ping_tx.clone(); tokio::spawn(async move { _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; @@ -344,8 +339,11 @@ where } // Received an error, terminate the connection. Either::Left((Either::Left((Some(Err(e)), s)), _)) => break Receive::Err(e, s), + // Server has been stopped or closed by inactive peer. - _ => break Receive::Stopped, + Either::Left((Either::Right((_, _)), _)) => break Receive::Stopped, + // Ping task has been stopped. + Either::Right((_, _)) => break Receive::KeepAliveExpired, } } } @@ -353,6 +351,7 @@ where #[derive(Debug, Copy, Clone)] enum KeepAlive { Ping(Instant), + Data(Instant), Pong(Instant), } @@ -360,7 +359,7 @@ async fn ping_pong_task( mut rx: mpsc::Receiver, max_inactive_limit: Duration, max_inactive: usize, - remote_addr: SocketAddr, + remote_addr: Option, ) { let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); let mut pending_pings: VecDeque = VecDeque::new(); @@ -378,11 +377,11 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; remove = true; - tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}ms/max={}ms", remote_addr, elapsed.as_millis(), max_inactive_limit.as_millis()); } if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}; closing connection", remote_addr); + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); break; } } @@ -396,7 +395,7 @@ async fn ping_pong_task( Some(KeepAlive::Ping(start)) => { pending_pings.push_back(start); } - Some(KeepAlive::Pong(end)) => { + Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => { if let Some(start) = pending_pings.pop_front() { // Calculate the round-trip time (RTT) of the ping/pong. // We adjust for the time to send it to this task. @@ -404,13 +403,13 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; - tracing::debug!(target: LOG_TARGET, "ping/pong keep alive expired for peer={:?}, elapsed={}s/max={}s", remote_addr, elapsed.as_secs(), max_inactive_limit.as_secs()); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}ms/max={}ms", remote_addr, elapsed.as_millis(), max_inactive_limit.as_millis()); } - tracing::debug!(target: LOG_TARGET, "ping_pong_rtt={:?}, peer={:?}", elapsed, remote_addr); + tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, peer={:?}", elapsed.as_millis(), remote_addr); if missed_pings >= max_inactive { - tracing::debug!(target: LOG_TARGET, "Too many missed ping/pongs for peer={:?}; closing connection", remote_addr); + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); break; } } @@ -426,6 +425,7 @@ async fn ping_pong_task( pub(crate) enum Shutdown { Stopped, ConnectionClosed, + KeepAliveExpired, } /// Enforce a graceful shutdown. From ce7336b512542406d7a395b75ab821be822b931b Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 7 Aug 2024 17:47:44 +0200 Subject: [PATCH 09/16] fix more nits --- server/src/transport/ws.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 0df349fd16..f529b31729 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -79,7 +79,7 @@ where } = params; // NOTE: jsonrpsee only inject the `remote_addr` if it not set because for servers that are behind a reverse proxy, - // needs read HTTP headers by the reverse proxy. + // needs read HTTP headers to get the real IP address of the client. let remote_addr = extensions.get::().copied(); let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; @@ -220,9 +220,6 @@ async fn send_task( ) { let ping_interval = match ping_config { None => IntervalStream::pending(), - // NOTE: we are emitted a tick here immediately to sync - // with how the receive task work because it starts measuring the pong - // when it starts up. Some(p) => IntervalStream::new(interval(p.ping_interval)), }; let rx = ReceiverStream::new(rx); @@ -248,7 +245,7 @@ async fn send_task( } if now.elapsed() > Duration::from_secs(30) { - tracing::warn!(target: LOG_TARGET, "Send message was slow {:?}, peer={:?}", now.elapsed(), remote_addr); + tracing::warn!(target: LOG_TARGET, "Send message was slow {}s, peer={:?}", now.elapsed().as_secs(), remote_addr); break; } @@ -396,9 +393,12 @@ async fn ping_pong_task( pending_pings.push_back(start); } Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => { + // Both pong and data are considered as a response to the ping. + // So we might get more responses than pings that's why it's possible + // that the pending_pings may be empty. if let Some(start) = pending_pings.pop_front() { // Calculate the round-trip time (RTT) of the ping/pong. - // We adjust for the time to send it to this task. + // We adjust for the time it took to send to this task. let elapsed = start.elapsed().saturating_sub(end.elapsed()); if elapsed >= max_inactive_limit { From 7cf608b778b01ff97e5d4cf39d8c5e04ec316a37 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 7 Aug 2024 18:22:05 +0200 Subject: [PATCH 10/16] remote_addr -> ip_addr --- server/src/server.rs | 38 +++++++++++++++++++------------------- server/src/transport/ws.rs | 8 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/server/src/server.rs b/server/src/server.rs index c485add6a3..f019c19234 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -240,8 +240,8 @@ pub struct TowerServiceBuilder { pub(crate) conn_id: Arc, /// Connection guard. pub(crate) conn_guard: ConnectionGuard, - /// Remote address. - pub(crate) remote_addr: Option, + /// IP address. + pub(crate) ip_addr: Option, } /// Configuration for batch request handling. @@ -463,7 +463,7 @@ pub struct Builder { server_cfg: ServerConfig, rpc_middleware: RpcServiceBuilder, http_middleware: tower::ServiceBuilder, - remote_addr: Option, + ip_addr: Option, } impl Default for Builder { @@ -472,7 +472,7 @@ impl Default for Builder { server_cfg: ServerConfig::default(), rpc_middleware: RpcServiceBuilder::new(), http_middleware: tower::ServiceBuilder::new(), - remote_addr: None, + ip_addr: None, } } } @@ -484,8 +484,8 @@ impl Builder { } /// Set the address of the remote peer. - pub fn set_remote_addr(mut self, remote_addr: SocketAddr) -> Self { - self.remote_addr = Some(remote_addr); + pub fn set_ip_addr(mut self, ip_addr: IpAddr) -> Self { + self.ip_addr = Some(ip_addr); self } } @@ -507,7 +507,7 @@ impl TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder { + pub fn set_ip_addr(self, ip_addr: IpAddr) -> TowerServiceBuilder { TowerServiceBuilder { server_cfg: self.server_cfg, rpc_middleware: self.rpc_middleware, http_middleware: self.http_middleware, conn_id: self.conn_id, conn_guard: self.conn_guard, - remote_addr: Some(remote_addr), + ip_addr: Some(ip_addr), } } } @@ -666,7 +666,7 @@ impl Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware, - remote_addr: self.remote_addr, + ip_addr: self.ip_addr, } } @@ -757,7 +757,7 @@ impl Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware, - remote_addr: self.remote_addr, + ip_addr: self.ip_addr, } } @@ -894,7 +894,7 @@ impl Builder { http_middleware: self.http_middleware, conn_id: Arc::new(AtomicU32::new(0)), conn_guard: ConnectionGuard::new(max_conns), - remote_addr: self.remote_addr, + ip_addr: self.ip_addr, } } @@ -976,8 +976,8 @@ struct ServiceData { conn_guard: ConnectionGuard, /// ServerConfig server_cfg: ServerConfig, - /// Remote address. - remote_addr: Option, + /// IP address. + ip_addr: Option, } /// jsonrpsee tower service @@ -1087,12 +1087,12 @@ where request.extensions_mut().insert::(conn.conn_id.into()); - if let Some(remote_addr) = self.inner.remote_addr { + if let Some(ip_addr) = self.inner.ip_addr { // Only insert the remote address if it's not already set. // We expect servers deployed behind a reverse proxy to set the remote address // themselves otherwise the remote address will be the address of the reverse proxy. if request.extensions().get::().is_none() { - request.extensions_mut().insert(remote_addr.ip()); + request.extensions_mut().insert(ip_addr); } } @@ -1253,7 +1253,7 @@ where stop_handle: stop_handle.clone(), conn_id, conn_guard: conn_guard.clone(), - remote_addr: Some(remote_addr), + ip_addr: Some(remote_addr.ip()), }, rpc_middleware, on_session_close: None, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index f529b31729..71f2fa39c5 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -80,17 +80,17 @@ where // NOTE: jsonrpsee only inject the `remote_addr` if it not set because for servers that are behind a reverse proxy, // needs read HTTP headers to get the real IP address of the client. - let remote_addr = extensions.get::().copied(); + let ip_addr = extensions.get::().copied(); let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); let (ping_tx, ping_rx) = mpsc::channel::(4); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), remote_addr)); + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), ip_addr)); if let Some(ping_config) = ping_config { - tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, remote_addr)); + tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, ip_addr)); } let stopped = conn.stop_handle.clone().shutdown(); @@ -194,7 +194,7 @@ where }); }; - tracing::debug!(target: LOG_TARGET, "Connection closed for peer={:?}, reason={:?}", remote_addr, result); + tracing::debug!(target: LOG_TARGET, "Connection closed for peer={:?}, reason={:?}", ip_addr, result); // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee From 705148f0c05dca1bf945e2387a866561a58beac5 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 7 Aug 2024 18:41:51 +0200 Subject: [PATCH 11/16] remove used import --- examples/examples/ws.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/examples/ws.rs b/examples/examples/ws.rs index 16bbd9c860..9a703a6a7f 100644 --- a/examples/examples/ws.rs +++ b/examples/examples/ws.rs @@ -25,7 +25,6 @@ // DEALINGS IN THE SOFTWARE. use std::net::SocketAddr; -use std::time::Duration; use jsonrpsee::core::client::ClientT; use jsonrpsee::server::{PingConfig, RpcServiceBuilder, Server}; From eac1b1b44ab59f3022f48e222ed48f6a839603dd Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 9 Aug 2024 10:53:35 +0200 Subject: [PATCH 12/16] check if ping_pong_task is spawned --- server/src/transport/ws.rs | 96 ++++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 36 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 71f2fa39c5..b0f3433f86 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -8,9 +8,9 @@ use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcSe use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET}; -use futures_util::future::{self, Either}; +use futures_util::future::{self, Either, Fuse}; use futures_util::io::{BufReader, BufWriter}; -use futures_util::{Future, StreamExt, TryStreamExt}; +use futures_util::{Future, FutureExt, StreamExt, TryStreamExt}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods}; @@ -84,14 +84,20 @@ where let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); - let (ping_tx, ping_rx) = mpsc::channel::(4); - // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ping_tx.clone(), ip_addr)); - - if let Some(ping_config) = ping_config { + // Spawn ping/pong task if ping config is provided. + let ping_config = if let Some(ping_config) = ping_config { + let (ping_tx, ping_rx) = mpsc::channel::(4); tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, ip_addr)); - } + Some((ping_config, ping_tx)) + } else { + None + }; + + let ping_tx = ping_config.as_ref().map(|(_, tx)| tx.clone()); + + // Spawn another task that sends out the responses on the Websocket. + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ip_addr)); let stopped = conn.stop_handle.clone().shutdown(); let rpc_service = Arc::new(rpc_service); @@ -114,7 +120,7 @@ where tokio::pin!(ws_stream); let result = loop { - let data = match try_recv(&mut ws_stream, stopped, &ping_tx).await { + let data = match try_recv(&mut ws_stream, stopped, ping_tx.as_ref()).await { Receive::ConnectionClosed => break Ok(Shutdown::ConnectionClosed), Receive::KeepAliveExpired => break Ok(Shutdown::KeepAliveExpired), Receive::Stopped => break Ok(Shutdown::Stopped), @@ -213,22 +219,26 @@ where async fn send_task( rx: mpsc::Receiver, mut ws_sender: Sender, - ping_config: Option, + ping_config: Option<(PingConfig, mpsc::Sender)>, stop: oneshot::Receiver<()>, - ping_tx: mpsc::Sender, remote_addr: Option, ) { - let ping_interval = match ping_config { - None => IntervalStream::pending(), - Some(p) => IntervalStream::new(interval(p.ping_interval)), + use futures_util::future::Either; + + // Ping task is only spawned if ping config is provided. + let ping = match ping_config { + None => Either::Left(IntervalStream::pending().map(|_| None)), + Some((p, ping_tx)) => { + Either::Right(IntervalStream::new(interval(p.ping_interval)).map(move |_| Some(ping_tx.clone()))) + } }; let rx = ReceiverStream::new(rx); - tokio::pin!(ping_interval, rx, stop); + tokio::pin!(ping, rx, stop); // Received messages from the WebSocket. let mut rx_item = rx.next(); - let next_ping = ping_interval.next(); + let next_ping = ping.next(); let mut futs = future::select(next_ping, stop); loop { @@ -259,7 +269,7 @@ async fn send_task( } // Handle timer intervals. - Either::Right((Either::Left((_instant, _stopped)), next_rx)) => { + Either::Right((Either::Left((Some(ping_tx), _stopped)), next_rx)) => { stop = _stopped; let now = Instant::now(); if let Err(err) = send_ping(&mut ws_sender).await { @@ -274,13 +284,20 @@ async fn send_task( rx_item = next_rx; - let ping_tx = ping_tx.clone(); + let ping_tx = ping_tx.expect("ping tx is only `None` if ping_config is `None` checked above; qed"); tokio::spawn(async move { ping_tx.send(KeepAlive::Ping(Instant::now())).await.ok(); }); - futs = future::select(ping_interval.next(), stop); + futs = future::select(ping.next(), stop); + } + + // The interval stream has been closed. + // This should be unreachable because the interval stream never ends. + Either::Right((Either::Left((None, _stopped)), _)) => { + break; } + Either::Right((Either::Right((_stopped, _)), _)) => { // server has stopped break; @@ -302,37 +319,44 @@ enum Receive { } /// Attempts to read data from WebSocket fails if the server was stopped. -async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: &mpsc::Sender) -> Receive +async fn try_recv(ws_stream: &mut T, stopped: S, ping_tx: Option<&mpsc::Sender>) -> Receive where S: Future + Unpin, T: StreamExt> + Unpin, { let mut futs = future::select(ws_stream.next(), stopped); + let closed = match ping_tx { + Some(ping_tx) => ping_tx.closed().fuse(), + None => Fuse::terminated(), + }; - loop { - let closed = ping_tx.closed(); - tokio::pin!(closed); + tokio::pin!(closed); + loop { match future::select(futs, closed).await { // The connection is closed. Either::Left((Either::Left((None, _)), _)) => break Receive::ConnectionClosed, // The message has been received, we are done Either::Left((Either::Left((Some(Ok(Incoming::Data(d))), s)), _)) => { - let ping_tx = ping_tx.clone(); - tokio::spawn(async move { - _ = ping_tx.send(KeepAlive::Data(Instant::now())).await; - }); + if let Some(ping_tx) = ping_tx { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Data(Instant::now())).await; + }); + } break Receive::Ok(d, s); } // Got a pong response send status to the ping_pong_task. - Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), _)) => { - let ping_tx = ping_tx.clone(); - tokio::spawn(async move { - _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; - }); - + Either::Left((Either::Left((Some(Ok(Incoming::Pong)), s)), c)) => { + if let Some(ping_tx) = ping_tx { + let ping_tx = ping_tx.clone(); + tokio::spawn(async move { + _ = ping_tx.send(KeepAlive::Pong(Instant::now())).await; + }); + } futs = futures_util::future::select(ws_stream.next(), s); + closed = c; } // Received an error, terminate the connection. Either::Left((Either::Left((Some(Err(e)), s)), _)) => break Receive::Err(e, s), @@ -355,7 +379,7 @@ enum KeepAlive { async fn ping_pong_task( mut rx: mpsc::Receiver, max_inactive_limit: Duration, - max_inactive: usize, + max_missed_pings: usize, remote_addr: Option, ) { let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); @@ -377,7 +401,7 @@ async fn ping_pong_task( tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}ms/max={}ms", remote_addr, elapsed.as_millis(), max_inactive_limit.as_millis()); } - if missed_pings >= max_inactive { + if missed_pings >= max_missed_pings { tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); break; } @@ -408,7 +432,7 @@ async fn ping_pong_task( tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, peer={:?}", elapsed.as_millis(), remote_addr); - if missed_pings >= max_inactive { + if missed_pings >= max_missed_pings { tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); break; } From 69e30c60308cab9d1eae52b1d162b0f2cb922b4d Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 9 Aug 2024 11:00:33 +0200 Subject: [PATCH 13/16] remove ip addr --- server/src/server.rs | 55 +++----------------------------------- server/src/transport/ws.rs | 36 +++++++------------------ 2 files changed, 13 insertions(+), 78 deletions(-) diff --git a/server/src/server.rs b/server/src/server.rs index f019c19234..dc9c036189 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -26,7 +26,7 @@ use std::error::Error as StdError; use std::future::Future; -use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener}; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::pin::Pin; use std::sync::atomic::AtomicU32; use std::sync::Arc; @@ -240,8 +240,6 @@ pub struct TowerServiceBuilder { pub(crate) conn_id: Arc, /// Connection guard. pub(crate) conn_guard: ConnectionGuard, - /// IP address. - pub(crate) ip_addr: Option, } /// Configuration for batch request handling. @@ -463,7 +461,6 @@ pub struct Builder { server_cfg: ServerConfig, rpc_middleware: RpcServiceBuilder, http_middleware: tower::ServiceBuilder, - ip_addr: Option, } impl Default for Builder { @@ -472,7 +469,6 @@ impl Default for Builder { server_cfg: ServerConfig::default(), rpc_middleware: RpcServiceBuilder::new(), http_middleware: tower::ServiceBuilder::new(), - ip_addr: None, } } } @@ -482,12 +478,6 @@ impl Builder { pub fn new() -> Self { Self::default() } - - /// Set the address of the remote peer. - pub fn set_ip_addr(mut self, ip_addr: IpAddr) -> Self { - self.ip_addr = Some(ip_addr); - self - } } impl TowerServiceBuilder { @@ -507,7 +497,6 @@ impl TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder TowerServiceBuilder { - TowerServiceBuilder { - server_cfg: self.server_cfg, - rpc_middleware: self.rpc_middleware, - http_middleware: self.http_middleware, - conn_id: self.conn_id, - conn_guard: self.conn_guard, - ip_addr: Some(ip_addr), } } } @@ -662,12 +637,7 @@ impl Builder { /// let builder = ServerBuilder::default().set_rpc_middleware(m); /// ``` pub fn set_rpc_middleware(self, rpc_middleware: RpcServiceBuilder) -> Builder { - Builder { - server_cfg: self.server_cfg, - rpc_middleware, - http_middleware: self.http_middleware, - ip_addr: self.ip_addr, - } + Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware } } /// Configure a custom [`tokio::runtime::Handle`] to run the server on. @@ -753,12 +723,7 @@ impl Builder { /// } /// ``` pub fn set_http_middleware(self, http_middleware: tower::ServiceBuilder) -> Builder { - Builder { - server_cfg: self.server_cfg, - http_middleware, - rpc_middleware: self.rpc_middleware, - ip_addr: self.ip_addr, - } + Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware } } /// Configure `TCP_NODELAY` on the socket to the supplied value `nodelay`. @@ -894,7 +859,6 @@ impl Builder { http_middleware: self.http_middleware, conn_id: Arc::new(AtomicU32::new(0)), conn_guard: ConnectionGuard::new(max_conns), - ip_addr: self.ip_addr, } } @@ -976,8 +940,6 @@ struct ServiceData { conn_guard: ConnectionGuard, /// ServerConfig server_cfg: ServerConfig, - /// IP address. - ip_addr: Option, } /// jsonrpsee tower service @@ -1087,15 +1049,6 @@ where request.extensions_mut().insert::(conn.conn_id.into()); - if let Some(ip_addr) = self.inner.ip_addr { - // Only insert the remote address if it's not already set. - // We expect servers deployed behind a reverse proxy to set the remote address - // themselves otherwise the remote address will be the address of the reverse proxy. - if request.extensions().get::().is_none() { - request.extensions_mut().insert(ip_addr); - } - } - let is_upgrade_request = is_upgrade_request(&request); if self.inner.server_cfg.enable_ws && is_upgrade_request { @@ -1237,7 +1190,6 @@ where stop_handle, drop_on_completion, methods, - remote_addr, .. } = params; @@ -1253,7 +1205,6 @@ where stop_handle: stop_handle.clone(), conn_id, conn_guard: conn_guard.clone(), - ip_addr: Some(remote_addr.ip()), }, rpc_middleware, on_session_close: None, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index b0f3433f86..a575cc8a98 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -78,9 +77,7 @@ where extensions, } = params; - // NOTE: jsonrpsee only inject the `remote_addr` if it not set because for servers that are behind a reverse proxy, - // needs read HTTP headers to get the real IP address of the client. - let ip_addr = extensions.get::().copied(); + let conn_id = conn.conn_id; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; let (conn_tx, conn_rx) = oneshot::channel(); @@ -88,7 +85,7 @@ where // Spawn ping/pong task if ping config is provided. let ping_config = if let Some(ping_config) = ping_config { let (ping_tx, ping_rx) = mpsc::channel::(4); - tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, ip_addr)); + tokio::spawn(ping_pong_task(ping_rx, ping_config.inactive_limit, ping_config.max_failures, conn_id)); Some((ping_config, ping_tx)) } else { None @@ -97,7 +94,7 @@ where let ping_tx = ping_config.as_ref().map(|(_, tx)| tx.clone()); // Spawn another task that sends out the responses on the Websocket. - let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx, ip_addr)); + let send_task_handle = tokio::spawn(send_task(rx, ws_sender, ping_config, conn_rx)); let stopped = conn.stop_handle.clone().shutdown(); let rpc_service = Arc::new(rpc_service); @@ -200,7 +197,7 @@ where }); }; - tracing::debug!(target: LOG_TARGET, "Connection closed for peer={:?}, reason={:?}", ip_addr, result); + tracing::debug!(target: LOG_TARGET, "Connection closed for conn_id={conn_id}, reason={:?}", result); // Drive all running methods to completion. // **NOTE** Do not return early in this function. This `await` needs to run to guarantee @@ -221,7 +218,6 @@ async fn send_task( mut ws_sender: Sender, ping_config: Option<(PingConfig, mpsc::Sender)>, stop: oneshot::Receiver<()>, - remote_addr: Option, ) { use futures_util::future::Either; @@ -248,17 +244,11 @@ async fn send_task( // Received message. Either::Left((Some(response), not_ready)) => { // If websocket message send fail then terminate the connection. - let now = Instant::now(); if let Err(err) = send_message(&mut ws_sender, response).await { tracing::debug!(target: LOG_TARGET, "WS send error: {}", err); break; } - if now.elapsed() > Duration::from_secs(30) { - tracing::warn!(target: LOG_TARGET, "Send message was slow {}s, peer={:?}", now.elapsed().as_secs(), remote_addr); - break; - } - rx_item = rx.next(); futs = not_ready; } @@ -271,17 +261,11 @@ async fn send_task( // Handle timer intervals. Either::Right((Either::Left((Some(ping_tx), _stopped)), next_rx)) => { stop = _stopped; - let now = Instant::now(); if let Err(err) = send_ping(&mut ws_sender).await { tracing::debug!(target: LOG_TARGET, "WS send ping error: {}", err); break; } - if now.elapsed() > Duration::from_secs(30) { - tracing::warn!(target: LOG_TARGET, "Send ping was slow {}s, peer={:?}", now.elapsed().as_secs(), remote_addr); - break; - } - rx_item = next_rx; let ping_tx = ping_tx.expect("ping tx is only `None` if ping_config is `None` checked above; qed"); @@ -380,7 +364,7 @@ async fn ping_pong_task( mut rx: mpsc::Receiver, max_inactive_limit: Duration, max_missed_pings: usize, - remote_addr: Option, + conn_id: u32, ) { let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); let mut pending_pings: VecDeque = VecDeque::new(); @@ -398,11 +382,11 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; remove = true; - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}ms/max={}ms", remote_addr, elapsed.as_millis(), max_inactive_limit.as_millis()); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis()); } if missed_pings >= max_missed_pings { - tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); break; } } @@ -427,13 +411,13 @@ async fn ping_pong_task( if elapsed >= max_inactive_limit { missed_pings += 1; - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for peer={:?}, elapsed={}ms/max={}ms", remote_addr, elapsed.as_millis(), max_inactive_limit.as_millis()); + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis()); } - tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, peer={:?}", elapsed.as_millis(), remote_addr); + tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={conn_id}", elapsed.as_millis()); if missed_pings >= max_missed_pings { - tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for peer={:?}; closing connection", remote_addr); + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); break; } } From 9d409b6b15ea99aa2789169eb884d89b91c23046 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 9 Aug 2024 12:43:41 +0200 Subject: [PATCH 14/16] check all pending pings on pong/data --- server/src/transport/ws.rs | 55 +++++------------------ server/src/utils.rs | 91 +++++++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 46 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index a575cc8a98..97cf6bee82 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,10 +1,10 @@ -use std::collections::VecDeque; use std::sync::Arc; use std::time::{Duration, Instant}; use crate::future::{IntervalStream, SessionClose}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; +use crate::utils::PendingPings; use crate::{HttpBody, HttpRequest, HttpResponse, PingConfig, LOG_TARGET}; use futures_util::future::{self, Either, Fuse}; @@ -354,7 +354,7 @@ where } #[derive(Debug, Copy, Clone)] -enum KeepAlive { +pub(crate) enum KeepAlive { Ping(Instant), Data(Instant), Pong(Instant), @@ -362,64 +362,29 @@ enum KeepAlive { async fn ping_pong_task( mut rx: mpsc::Receiver, - max_inactive_limit: Duration, + max_inactivity_dur: Duration, max_missed_pings: usize, conn_id: u32, ) { - let mut polling_interval = IntervalStream::new(interval(max_inactive_limit)); - let mut pending_pings: VecDeque = VecDeque::new(); - let mut missed_pings = 0; + let mut polling_interval = IntervalStream::new(interval(max_inactivity_dur)); + let mut pending_pings = PendingPings::new(max_missed_pings, max_inactivity_dur, conn_id); loop { tokio::select! { // If the ping is never answered, we use this timer as a fallback. _ = polling_interval.next() => { - let mut remove = false; - - if let Some(ping_start) = pending_pings.front() { - let elapsed = ping_start.elapsed(); - - if elapsed >= max_inactive_limit { - missed_pings += 1; - remove = true; - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis()); - } - - if missed_pings >= max_missed_pings { - tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); - break; - } - } - - if remove { - pending_pings.pop_front(); + if !pending_pings.check_pending(Instant::now()) { + break; } } msg = rx.recv() => { match msg { Some(KeepAlive::Ping(start)) => { - pending_pings.push_back(start); + pending_pings.push(start); } Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => { - // Both pong and data are considered as a response to the ping. - // So we might get more responses than pings that's why it's possible - // that the pending_pings may be empty. - if let Some(start) = pending_pings.pop_front() { - // Calculate the round-trip time (RTT) of the ping/pong. - // We adjust for the time it took to send to this task. - let elapsed = start.elapsed().saturating_sub(end.elapsed()); - - if elapsed >= max_inactive_limit { - missed_pings += 1; - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactive_limit.as_millis()); - } - - tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={conn_id}", elapsed.as_millis()); - - if missed_pings >= max_missed_pings { - tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); - break; - } + if !pending_pings.check_pending(end) { + break; } } None => break, diff --git a/server/src/utils.rs b/server/src/utils.rs index d510a84661..9dc46b602e 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -24,11 +24,13 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; -use crate::{HttpBody, HttpRequest}; +use crate::{HttpBody, HttpRequest, LOG_TARGET}; use futures_util::future::{self, Either}; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -83,6 +85,50 @@ where } } +#[derive(Debug, Clone)] +pub(crate) struct PendingPings { + list: VecDeque, + max_missed_pings: usize, + missed_pings: usize, + max_inactivity_dur: Duration, + conn_id: u32, +} + +impl PendingPings { + pub(crate) fn new(max_missed_pings: usize, max_inactivity_dur: Duration, conn_id: u32) -> Self { + Self { list: VecDeque::new(), max_missed_pings, max_inactivity_dur, missed_pings: 0, conn_id } + } + + pub(crate) fn push(&mut self, instant: Instant) { + self.list.push_back(instant); + } + + /// Returns `true` if the pong was answered in time, `false` otherwise. + pub(crate) fn check_pending(&mut self, end: Instant) -> bool { + for ping_start in self.list.drain(..) { + // Calculate the round-trip time (RTT) of the ping/pong. + // We adjust for the time when the pong was received. + let elapsed = ping_start.elapsed().saturating_sub(end.elapsed()); + + tracing::trace!(target: LOG_TARGET, "ws_ping_pong_rtt={}ms, conn_id={}", elapsed.as_millis(), self.conn_id); + + if elapsed >= self.max_inactivity_dur { + self.missed_pings += 1; + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={}, elapsed={}ms/max={}ms", self.conn_id, elapsed.as_millis(), self.max_inactivity_dur.as_millis()); + } else { + self.missed_pings = 0; + } + + if self.missed_pings >= self.max_missed_pings { + tracing::debug!(target: LOG_TARGET, "Missed {} ping/pongs for conn_id={}; closing connection", self.missed_pings, self.conn_id); + return false; + } + } + + true + } +} + /// Serve a service over a TCP connection without graceful shutdown. /// This means that pending requests will be dropped when the server is stopped. /// @@ -163,3 +209,46 @@ pub(crate) mod deserialize { Ok(req) } } + +#[cfg(test)] +mod tests { + use super::PendingPings; + use std::time::Duration; + + #[test] + fn pending_ping_works() { + let mut pending_pings = PendingPings::new(1, std::time::Duration::from_secs(1), 0); + + pending_pings.push(std::time::Instant::now()); + assert!(pending_pings.check_pending(std::time::Instant::now())); + assert!(pending_pings.list.is_empty()); + } + + #[test] + fn inactive_too_long() { + let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); + + pending_pings.push(std::time::Instant::now()); + pending_pings.push(std::time::Instant::now()); + + std::thread::sleep(Duration::from_millis(200)); + + assert!(!pending_pings.check_pending(std::time::Instant::now())); + } + + #[test] + fn active_reset_counter() { + let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); + + pending_pings.push(std::time::Instant::now()); + + std::thread::sleep(Duration::from_millis(200)); + + assert!(pending_pings.check_pending(std::time::Instant::now())); + assert_eq!(pending_pings.missed_pings, 1); + + pending_pings.push(std::time::Instant::now()); + assert!(pending_pings.check_pending(std::time::Instant::now())); + assert_eq!(pending_pings.missed_pings, 0); + } +} From 4abb1daf7d920f30de92a7a940ba9720660af28a Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 9 Aug 2024 12:47:28 +0200 Subject: [PATCH 15/16] fix nit --- server/src/transport/ws.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 97cf6bee82..6ff7fd3d00 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -219,8 +219,6 @@ async fn send_task( ping_config: Option<(PingConfig, mpsc::Sender)>, stop: oneshot::Receiver<()>, ) { - use futures_util::future::Either; - // Ping task is only spawned if ping config is provided. let ping = match ping_config { None => Either::Left(IntervalStream::pending().map(|_| None)), From 66af3a67acbe757675e01782ac362bf14d681401 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Fri, 9 Aug 2024 15:56:59 +0200 Subject: [PATCH 16/16] add apis for checking and response --- server/src/transport/ws.rs | 5 +-- server/src/utils.rs | 65 +++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 6ff7fd3d00..5efc2c4433 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -371,17 +371,18 @@ async fn ping_pong_task( tokio::select! { // If the ping is never answered, we use this timer as a fallback. _ = polling_interval.next() => { - if !pending_pings.check_pending(Instant::now()) { + if !pending_pings.check_alive() { break; } } + // Data on the connection. msg = rx.recv() => { match msg { Some(KeepAlive::Ping(start)) => { pending_pings.push(start); } Some(KeepAlive::Pong(end)) | Some(KeepAlive::Data(end)) => { - if !pending_pings.check_pending(end) { + if !pending_pings.alive_response(end) { break; } } diff --git a/server/src/utils.rs b/server/src/utils.rs index 9dc46b602e..dc3a86da0e 100644 --- a/server/src/utils.rs +++ b/server/src/utils.rs @@ -99,12 +99,52 @@ impl PendingPings { Self { list: VecDeque::new(), max_missed_pings, max_inactivity_dur, missed_pings: 0, conn_id } } + fn log_ping_expired(elapsed: Duration, conn_id: u32, max_inactivity_dur: Duration) { + tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive for conn_id={conn_id}, elapsed={}ms/max={}ms", elapsed.as_millis(), max_inactivity_dur.as_millis()); + } + + fn log_connection_closed(missed_pings: usize, conn_id: u32) { + tracing::debug!(target: LOG_TARGET, "Missed {missed_pings} ping/pongs for conn_id={conn_id}; closing connection"); + } + pub(crate) fn push(&mut self, instant: Instant) { self.list.push_back(instant); } + /// Check if there are any pending pings that have expired + /// + /// It's different from [`PendingPing::alive_response`] because + /// this shouldn't be used when data is received. + /// + /// It's just way to ensure that pings are checked despite no message is received on the + /// connection. + /// + /// Returns `true` if the connection is still alive, `false` otherwise. + pub(crate) fn check_alive(&mut self) -> bool { + let mut list = VecDeque::new(); + + for ping_start in self.list.drain(..) { + if ping_start.elapsed() >= self.max_inactivity_dur { + self.missed_pings += 1; + Self::log_ping_expired(ping_start.elapsed(), self.conn_id, self.max_inactivity_dur); + } else { + list.push_back(ping_start); + } + + if self.missed_pings >= self.max_missed_pings { + Self::log_connection_closed(self.missed_pings, self.conn_id); + return false; + } + } + + self.list = list; + true + } + + /// Register a alive response. + /// /// Returns `true` if the pong was answered in time, `false` otherwise. - pub(crate) fn check_pending(&mut self, end: Instant) -> bool { + pub(crate) fn alive_response(&mut self, end: Instant) -> bool { for ping_start in self.list.drain(..) { // Calculate the round-trip time (RTT) of the ping/pong. // We adjust for the time when the pong was received. @@ -114,13 +154,13 @@ impl PendingPings { if elapsed >= self.max_inactivity_dur { self.missed_pings += 1; - tracing::debug!(target: LOG_TARGET, "Ping/pong keep alive expired for conn_id={}, elapsed={}ms/max={}ms", self.conn_id, elapsed.as_millis(), self.max_inactivity_dur.as_millis()); + Self::log_ping_expired(ping_start.elapsed(), self.conn_id, self.max_inactivity_dur); } else { self.missed_pings = 0; } if self.missed_pings >= self.max_missed_pings { - tracing::debug!(target: LOG_TARGET, "Missed {} ping/pongs for conn_id={}; closing connection", self.missed_pings, self.conn_id); + Self::log_connection_closed(self.missed_pings, self.conn_id); return false; } } @@ -213,42 +253,43 @@ pub(crate) mod deserialize { #[cfg(test)] mod tests { use super::PendingPings; - use std::time::Duration; + use std::time::{Duration, Instant}; #[test] fn pending_ping_works() { let mut pending_pings = PendingPings::new(1, std::time::Duration::from_secs(1), 0); - pending_pings.push(std::time::Instant::now()); - assert!(pending_pings.check_pending(std::time::Instant::now())); + pending_pings.push(Instant::now()); + assert!(pending_pings.alive_response(std::time::Instant::now())); assert!(pending_pings.list.is_empty()); + assert_eq!(pending_pings.missed_pings, 0); } #[test] fn inactive_too_long() { let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); - pending_pings.push(std::time::Instant::now()); - pending_pings.push(std::time::Instant::now()); + pending_pings.push(Instant::now()); + pending_pings.push(Instant::now()); std::thread::sleep(Duration::from_millis(200)); - assert!(!pending_pings.check_pending(std::time::Instant::now())); + assert!(!pending_pings.check_alive()); + assert_eq!(pending_pings.missed_pings, 2); } #[test] fn active_reset_counter() { let mut pending_pings = PendingPings::new(2, std::time::Duration::from_millis(100), 0); - pending_pings.push(std::time::Instant::now()); std::thread::sleep(Duration::from_millis(200)); - assert!(pending_pings.check_pending(std::time::Instant::now())); + assert!(pending_pings.check_alive()); assert_eq!(pending_pings.missed_pings, 1); pending_pings.push(std::time::Instant::now()); - assert!(pending_pings.check_pending(std::time::Instant::now())); + assert!(pending_pings.alive_response(Instant::now())); assert_eq!(pending_pings.missed_pings, 0); } }