From c1da2026d09d5d11018a3f1b3eb281605832933f Mon Sep 17 00:00:00 2001 From: Fraser Hutchison Date: Thu, 29 Feb 2024 20:53:38 +0000 Subject: [PATCH 1/4] provide more info in PayloadExceedsLimit error --- src/protocol.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 82145a5..43324ff 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -499,7 +499,12 @@ pub enum LocalProtocolViolation { /// See [`ChannelConfiguration::with_max_request_payload_size()`] and /// [`ChannelConfiguration::with_max_response_payload_size()`] for details. #[error("payload exceeds configured limit")] - PayloadExceedsLimit, + PayloadExceedsLimit { + /// The payload length in bytes. + payload_length: usize, + /// The configured upper limit for payload length in bytes. + limit: usize, + }, /// The given error payload exceeds a single frame. /// /// Error payloads may not span multiple frames, shorten the payload or increase frame size. @@ -595,7 +600,10 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_request_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } @@ -637,7 +645,10 @@ impl JulietProtocol { if let Some(ref payload) = payload { if payload.len() > chan.config.max_response_payload_size as usize { - return Err(LocalProtocolViolation::PayloadExceedsLimit); + return Err(LocalProtocolViolation::PayloadExceedsLimit { + payload_length: payload.len(), + limit: chan.config.max_request_payload_size as usize, + }); } } @@ -2202,7 +2213,10 @@ mod tests { .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large request"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Bob must refuse it instead. let bob_result = env.inject_and_send_request(Alice, payload.get()); @@ -2219,7 +2233,10 @@ mod tests { .bob .create_request(env.common_channel, payload.get()) .expect_err("should not be able to create too large response"); - assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit); + assert_matches!( + violation, + LocalProtocolViolation::PayloadExceedsLimit { .. } + ); // If we force the issue, Alice must refuse it. let alice_result = env.inject_and_send_response(Bob, id, payload.get()); From c64366a5380c99669d3a2225774b4a0afc2f4441 Mon Sep 17 00:00:00 2001 From: Fraser Hutchison Date: Tue, 5 Mar 2024 11:00:51 +0000 Subject: [PATCH 2/4] avoid deadlocking in rpc test --- src/rpc.rs | 69 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/src/rpc.rs b/src/rpc.rs index 1bd80a5..cebb8bd 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -766,7 +766,7 @@ impl IncomingRequest { // Do nothing, just discard the response. } EnqueueError::BufferLimitHit(_) => { - // TODO: Add seperate type to avoid this. + // TODO: Add separate type to avoid this. unreachable!("cannot hit request limit when responding") } } @@ -851,7 +851,10 @@ mod tests { use bytes::Bytes; use futures::FutureExt; - use tokio::io::{DuplexStream, ReadHalf, WriteHalf}; + use tokio::{ + io::{DuplexStream, ReadHalf, WriteHalf}, + sync::mpsc, + }; use tracing::{error_span, info, span, Instrument, Level}; use crate::{ @@ -1330,7 +1333,7 @@ mod tests { large_volume_test::<1>(spec).await; } - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn run_large_volume_test_with_default_values_10_channels() { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) @@ -1352,7 +1355,7 @@ mod tests { let (mut alice, mut bob) = LargeVolumeTestSpec::::default().mk_rpc(); - // Alice server. Will close the connection after enough bytes have been sent. + // Alice server. Will close the connection after enough bytes have been received. let mut remaining = spec.min_send_bytes; let alice_server = tokio::spawn( async move { @@ -1371,6 +1374,7 @@ mod tests { request.respond(None); remaining = remaining.saturating_sub(payload_size); + tracing::debug!("payload_size: {payload_size}, remaining: {remaining}"); if remaining == 0 { // We've reached the volume we were looking for, end test. break; @@ -1420,14 +1424,18 @@ mod tests { Err(guard) => { // Not ready, but we are not going to wait. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } @@ -1437,10 +1445,11 @@ mod tests { .instrument(error_span!("alice_client")), ); - // Bob server. + // A channel to allow Bob's server to notify Bob's client to send a new request to Alice. + let (notify_tx, mut notify_rx) = mpsc::unbounded_channel(); + // Bob server. Will shut down once Alice closes the connection. let bob_server = tokio::spawn( async move { - let mut bob_counter = 0; while let Some(request) = bob .server .next_request() @@ -1459,7 +1468,19 @@ mod tests { let channel = request.channel(); // Just discard the message payload, but acknowledge receiving it. request.respond(None); + // Notify Bob client to send a new request to Alice. + notify_tx.send(channel).unwrap(); + } + info!("exiting"); + } + .instrument(error_span!("bob_server")), + ); + // Bob client. Will shut down once Alice closes the connection. + let bob_client = tokio::spawn( + async move { + let mut bob_counter = 0; + while let Some(channel) = notify_rx.recv().await { let payload_size = spec.gen_payload_size(bob_counter); let large_payload: Bytes = iter::repeat(0xFF) .take(payload_size) @@ -1470,11 +1491,11 @@ mod tests { let bobs_request: RequestGuard = bob .client .create_request(channel) - .with_payload(large_payload.clone()) + .with_payload(large_payload) .queue_for_sending() .await; - info!(bob_counter, "bob enqueued request"); + info!(bob_counter, payload_size, "bob enqueued request"); bob_counter += 1; match bobs_request.try_get_response() { @@ -1492,26 +1513,30 @@ mod tests { Err(guard) => { // Do not wait, instead attempt to retrieve next request. - tokio::spawn(async move { - if let Err(err) = guard.wait_for_response().await { - match err { - RequestError::RemoteClosed(_) | RequestError::Shutdown => {} - err => panic!("{}", err), + tokio::spawn( + async move { + if let Err(err) = guard.wait_for_response().await { + match err { + RequestError::RemoteClosed(_) + | RequestError::Shutdown => {} + err => panic!("{}", err), + } } } - }); + .in_current_span(), + ); } } } - info!("exiting"); } - .instrument(error_span!("bob_server")), + .instrument(error_span!("bob_client")), ); alice_server.await.expect("failed to join alice server"); alice_client.await.expect("failed to join alice client"); bob_server.await.expect("failed to join bob server"); + bob_client.await.expect("failed to join bob client"); info!("all joined"); } From 4c710554a15415cbf579a1ec6da42adf3dad8e1e Mon Sep 17 00:00:00 2001 From: Fraser Hutchison Date: Tue, 5 Mar 2024 16:25:50 +0000 Subject: [PATCH 3/4] add further info to LocalProtocolViolation variants --- src/protocol.rs | 99 ++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 77 insertions(+), 22 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 43324ff..759ddc1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -305,6 +305,11 @@ impl Channel { self.outgoing_requests.len() < self.config.request_limit as usize } + /// Returns the configured request limit for this channel. + pub fn request_limit(&self) -> u16 { + self.config.request_limit + } + /// Creates a new request, bypassing all client-side checks. /// /// Low-level function that does nothing but create a syntactically correct request and track @@ -474,7 +479,7 @@ impl Display for CompletedRead { } } -/// The caller of the this crate has violated the protocol. +/// The caller of this crate has violated the protocol. /// /// A correct implementation of a client should never encounter this, thus simply unwrapping every /// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice. @@ -487,18 +492,26 @@ pub enum LocalProtocolViolation { /// /// Wait for additional requests to be cancelled or answered. Calling /// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended. - #[error("sending would exceed request limit")] - WouldExceedRequestLimit, + #[error("sending would exceed request limit of {limit}")] + WouldExceedRequestLimit { + /// The configured limit for requests on the channel. + limit: u16, + }, /// The channel given does not exist. /// /// The given [`ChannelId`] exceeds `N` of [`JulietProtocol`]. - #[error("invalid channel")] - InvalidChannel(ChannelId), + #[error("channel {channel} not a member of configured {channel_count} channels")] + InvalidChannel { + /// The provided channel ID. + channel: ChannelId, + /// The configured number of channels. + channel_count: usize, + }, /// The given payload exceeds the configured limit. /// /// See [`ChannelConfiguration::with_max_request_payload_size()`] and /// [`ChannelConfiguration::with_max_response_payload_size()`] for details. - #[error("payload exceeds configured limit")] + #[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")] PayloadExceedsLimit { /// The payload length in bytes. payload_length: usize, @@ -508,8 +521,16 @@ pub enum LocalProtocolViolation { /// The given error payload exceeds a single frame. /// /// Error payloads may not span multiple frames, shorten the payload or increase frame size. - #[error("error payload would be multi-frame")] - ErrorPayloadIsMultiFrame, + #[error( + "error payload of {payload_length} bytes exceeds a single frame with configured max size \ + of {max_frame_size})" + )] + ErrorPayloadIsMultiFrame { + /// The payload length in bytes. + payload_length: usize, + /// The configured maximum frame size in bytes. + max_frame_size: u32, + }, } macro_rules! log_frame { @@ -539,7 +560,10 @@ impl JulietProtocol { #[inline(always)] const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&self.channels[channel.0 as usize]) } @@ -554,7 +578,10 @@ impl JulietProtocol { channel: ChannelId, ) -> Result<&mut Channel, LocalProtocolViolation> { if channel.0 as usize >= N { - Err(LocalProtocolViolation::InvalidChannel(channel)) + Err(LocalProtocolViolation::InvalidChannel { + channel, + channel_count: N, + }) } else { Ok(&mut self.channels[channel.0 as usize]) } @@ -608,7 +635,9 @@ impl JulietProtocol { } if !chan.allowed_to_send_request() { - return Err(LocalProtocolViolation::WouldExceedRequestLimit); + return Err(LocalProtocolViolation::WouldExceedRequestLimit { + limit: chan.request_limit(), + }); } Ok(chan.create_unchecked_request(channel, payload)) @@ -723,11 +752,15 @@ impl JulietProtocol { id: Id, payload: Bytes, ) -> Result { - let header = Header::new_error(header::ErrorKind::Other, channel, id); + let header = Header::new_error(ErrorKind::Other, channel, id); + let payload_length = payload.len(); let msg = OutgoingMessage::new(header, Some(payload)); if msg.is_multi_frame(self.max_frame_size) { - Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame) + Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame { + payload_length, + max_frame_size: self.max_frame_size.0, + }) } else { Ok(msg) } @@ -1264,7 +1297,8 @@ mod tests { #[test] fn test_channel_lookups_work() { - let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build(); + const CHANNEL_COUNT: usize = 3; + let mut protocol: JulietProtocol = ProtocolBuilder::new().build(); // We mark channels by inserting an ID into them, that way we can ensure we're not getting // back the same channel every time. @@ -1285,15 +1319,24 @@ mod tests { .insert(Id::new(102)); assert!(matches!( protocol.lookup_channel_mut(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel_mut(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); // Now look up the channels and ensure they contain the right values @@ -1320,15 +1363,24 @@ mod tests { ); assert!(matches!( protocol.lookup_channel(ChannelId(3)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(3))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(3), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(4)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(4))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(4), + channel_count: CHANNEL_COUNT + }) )); assert!(matches!( protocol.lookup_channel(ChannelId(255)), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(255))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(255), + channel_count: CHANNEL_COUNT + }) )); } @@ -1453,7 +1505,10 @@ mod tests { // Try an invalid channel, should result in an error. assert!(matches!( protocol.create_request(ChannelId::new(2), payload.get()), - Err(LocalProtocolViolation::InvalidChannel(ChannelId(2))) + Err(LocalProtocolViolation::InvalidChannel { + channel: ChannelId(2), + channel_count: 2 + }) )); assert!(protocol @@ -1465,7 +1520,7 @@ mod tests { assert!(matches!( protocol.create_request(channel, payload.get()), - Err(LocalProtocolViolation::WouldExceedRequestLimit) + Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 }) )); } } From 9db8b5ea3b6befe8e52026756d4a9051cdf1e590 Mon Sep 17 00:00:00 2001 From: Fraser Hutchison Date: Tue, 5 Mar 2024 16:28:38 +0000 Subject: [PATCH 4/4] appease clippy --- src/protocol/multiframe.rs | 1 + src/rpc.rs | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/protocol/multiframe.rs b/src/protocol/multiframe.rs index de6a913..75c908a 100644 --- a/src/protocol/multiframe.rs +++ b/src/protocol/multiframe.rs @@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver { /// The outcome of a multiframe acceptance. #[derive(Debug)] +#[allow(clippy::enum_variant_names)] pub(crate) enum CompletedFrame { /// A new multi-frame transfer was started. NewMultiFrame, diff --git a/src/rpc.rs b/src/rpc.rs index cebb8bd..4f3e645 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -926,7 +926,7 @@ mod tests { async fn run_echo_client( mut rpc_server: JulietRpcServer, WriteHalf>, ) { - while let Some(inc) = rpc_server + if let Some(inc) = rpc_server .next_request() .await .expect("client rpc_server error") @@ -1576,7 +1576,7 @@ mod tests { let mut bob = CompleteSetup::new(&rpc_builder, bob_stream); let alice_join_handle = tokio::spawn(async move { - while let Some(incoming_request) = alice + if let Some(incoming_request) = alice .server .next_request() .await