Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply fixes for deadlock in RPC test #8

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 99 additions & 27 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ impl Channel {
self.outgoing_requests.len() < self.config.request_limit as usize
}

/// Returns the configured request limit for this channel.
pub fn request_limit(&self) -> u16 {
self.config.request_limit
}

/// Creates a new request, bypassing all client-side checks.
///
/// Low-level function that does nothing but create a syntactically correct request and track
Expand Down Expand Up @@ -474,7 +479,7 @@ impl Display for CompletedRead {
}
}

/// The caller of the this crate has violated the protocol.
/// The caller of this crate has violated the protocol.
///
/// A correct implementation of a client should never encounter this, thus simply unwrapping every
/// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice.
Expand All @@ -487,24 +492,45 @@ pub enum LocalProtocolViolation {
///
/// Wait for additional requests to be cancelled or answered. Calling
/// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended.
#[error("sending would exceed request limit")]
WouldExceedRequestLimit,
#[error("sending would exceed request limit of {limit}")]
WouldExceedRequestLimit {
/// The configured limit for requests on the channel.
limit: u16,
},
/// The channel given does not exist.
///
/// The given [`ChannelId`] exceeds `N` of [`JulietProtocol<N>`].
#[error("invalid channel")]
InvalidChannel(ChannelId),
#[error("channel {channel} not a member of configured {channel_count} channels")]
InvalidChannel {
/// The provided channel ID.
channel: ChannelId,
/// The configured number of channels.
channel_count: usize,
},
/// The given payload exceeds the configured limit.
///
/// See [`ChannelConfiguration::with_max_request_payload_size()`] and
/// [`ChannelConfiguration::with_max_response_payload_size()`] for details.
#[error("payload exceeds configured limit")]
PayloadExceedsLimit,
#[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")]
PayloadExceedsLimit {
/// The payload length in bytes.
payload_length: usize,
/// The configured upper limit for payload length in bytes.
limit: usize,
},
/// The given error payload exceeds a single frame.
///
/// Error payloads may not span multiple frames, shorten the payload or increase frame size.
#[error("error payload would be multi-frame")]
ErrorPayloadIsMultiFrame,
#[error(
"error payload of {payload_length} bytes exceeds a single frame with configured max size \
of {max_frame_size})"
)]
ErrorPayloadIsMultiFrame {
/// The payload length in bytes.
payload_length: usize,
/// The configured maximum frame size in bytes.
max_frame_size: u32,
},
}

macro_rules! log_frame {
Expand Down Expand Up @@ -534,7 +560,10 @@ impl<const N: usize> JulietProtocol<N> {
#[inline(always)]
const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&self.channels[channel.0 as usize])
}
Expand All @@ -549,7 +578,10 @@ impl<const N: usize> JulietProtocol<N> {
channel: ChannelId,
) -> Result<&mut Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&mut self.channels[channel.0 as usize])
}
Expand Down Expand Up @@ -595,12 +627,17 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_request_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_request_payload_size as usize,
});
}
}

if !chan.allowed_to_send_request() {
return Err(LocalProtocolViolation::WouldExceedRequestLimit);
return Err(LocalProtocolViolation::WouldExceedRequestLimit {
limit: chan.request_limit(),
});
}

Ok(chan.create_unchecked_request(channel, payload))
Expand Down Expand Up @@ -637,7 +674,10 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_response_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_request_payload_size as usize,
});
}
}

Expand Down Expand Up @@ -712,11 +752,15 @@ impl<const N: usize> JulietProtocol<N> {
id: Id,
payload: Bytes,
) -> Result<OutgoingMessage, LocalProtocolViolation> {
let header = Header::new_error(header::ErrorKind::Other, channel, id);
let header = Header::new_error(ErrorKind::Other, channel, id);

let payload_length = payload.len();
let msg = OutgoingMessage::new(header, Some(payload));
if msg.is_multi_frame(self.max_frame_size) {
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame)
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame {
payload_length,
max_frame_size: self.max_frame_size.0,
})
} else {
Ok(msg)
}
Expand Down Expand Up @@ -1253,7 +1297,8 @@ mod tests {

#[test]
fn test_channel_lookups_work() {
let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build();
const CHANNEL_COUNT: usize = 3;
let mut protocol: JulietProtocol<CHANNEL_COUNT> = ProtocolBuilder::new().build();

// We mark channels by inserting an ID into them, that way we can ensure we're not getting
// back the same channel every time.
Expand All @@ -1274,15 +1319,24 @@ mod tests {
.insert(Id::new(102));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));

// Now look up the channels and ensure they contain the right values
Expand All @@ -1309,15 +1363,24 @@ mod tests {
);
assert!(matches!(
protocol.lookup_channel(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));
}

Expand Down Expand Up @@ -1442,7 +1505,10 @@ mod tests {
// Try an invalid channel, should result in an error.
assert!(matches!(
protocol.create_request(ChannelId::new(2), payload.get()),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(2)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(2),
channel_count: 2
})
));

assert!(protocol
Expand All @@ -1454,7 +1520,7 @@ mod tests {

assert!(matches!(
protocol.create_request(channel, payload.get()),
Err(LocalProtocolViolation::WouldExceedRequestLimit)
Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 })
));
}
}
Expand Down Expand Up @@ -2202,7 +2268,10 @@ mod tests {
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large request");

assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Bob must refuse it instead.
let bob_result = env.inject_and_send_request(Alice, payload.get());
Expand All @@ -2219,7 +2288,10 @@ mod tests {
.bob
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large response");
assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Alice must refuse it.
let alice_result = env.inject_and_send_response(Bob, id, payload.get());
Expand Down
1 change: 1 addition & 0 deletions src/protocol/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver {

/// The outcome of a multiframe acceptance.
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub(crate) enum CompletedFrame {
/// A new multi-frame transfer was started.
NewMultiFrame,
Expand Down
Loading