Fix large volume test #7

Open · wants to merge 8 commits into base: main
Changes from 7 commits
31 changes: 22 additions & 9 deletions src/io.rs
@@ -559,27 +559,29 @@ where
let header_sent = frame_sent.header();

// If we finished the active multi frame send, clear it.
let mut cleared_multi_frame = false;
if was_final {
let channel_idx = header_sent.channel().get() as usize;
if let Some(ref active_multi_frame) =
self.active_multi_frame[channel_idx] {
if header_sent == *active_multi_frame {
self.active_multi_frame[channel_idx] = None;
cleared_multi_frame = true;
}
}
}
};

if header_sent.is_error() {
// We finished sending an error frame, time to exit.
return Err(CoreError::RemoteProtocolViolation(header_sent));
}

// TODO: We should restrict the dirty-queue processing here a little bit
// (only check when completing a multi-frame message).
// A message has completed sending, process the wait queue in case we have
// to start sending a multi-frame message like a response that was delayed
// only because of the one-multi-frame-per-channel restriction.
self.process_wait_queue(header_sent.channel())?;
if cleared_multi_frame {
self.process_wait_queue(header_sent.channel())?;
}
} else {
#[cfg(feature = "tracing")]
tracing::error!("current frame should not disappear");
@@ -719,6 +721,16 @@ where

/// Handles a new item to send out that arrived through the incoming channel.
fn handle_incoming_item(&mut self, item: QueuedItem) -> Result<(), LocalProtocolViolation> {
// Process the wait queue to avoid this new item "jumping the queue".
match &item {
Contributor:

1dff662 seems to add starvation protection, i.e. newer data cannot consistently get in front of existing data. Was this behavior observed to be problematic?

My core issue with this is that if we process the wait queue each time anyway, it might be better to not even check if we can bypass it and just put everything in the wait queue every time. However, processing the wait queue is expensive, especially if the previously mentioned change is made. Queuing messages will then result in quadratic complexity!
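To make the worst case concrete, here is a toy cost model (illustration only, not code from this crate): if the i-th enqueue rescans the i items already waiting, n enqueues visit items roughly n²/2 times in total.

```rust
/// Toy cost model for re-processing the wait queue on every enqueue:
/// the i-th enqueue scans the i items already waiting, so n enqueues
/// perform 1 + 2 + ... + n = n * (n + 1) / 2 item visits in total.
fn total_item_visits(n: usize) -> usize {
    (1..=n).sum()
}

fn main() {
    // Queuing 1_000 messages already costs half a million item visits.
    println!("{}", total_item_visits(1_000)); // prints 500500
}
```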

Contributor Author:

It was problematic in that Alice's requests started timing out, as ones in the wait queue weren't processed since newer ones kept getting preferential treatment.

I did consider just dumping everything in the wait queue, however I had the same reservation as you about the cost (and it also seemed to be somewhat abusing the intent of the wait queue - it would at least need renaming for clarity, I think, if we did that).

Contributor:

At the very least add QueuedItem::is_request :)

This may be less of an issue if the "new" WaitQueue (see above) is added.

Contributor Author:

At the very least add QueuedItem::is_request :)

I can do, but tbh I don't see why we'd want that or where we'd use it?
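For reference, the suggested helper is tiny. A sketch against a simplified stand-in for the crate's `QueuedItem` (the real variants carry channel, id and payload fields); one obvious use is the `prefer_request` bookkeeping in the `WaitQueue` sketch further down this thread:

```rust
// Simplified stand-in for the crate's `QueuedItem`; the real variants
// carry channel, id and payload data.
#[derive(Debug)]
enum QueuedItem {
    Request,
    Response,
    RequestCancellation,
    ResponseCancellation,
    Error,
}

impl QueuedItem {
    /// Returns `true` if this item is an outgoing request.
    fn is_request(&self) -> bool {
        matches!(self, QueuedItem::Request)
    }
}

fn main() {
    assert!(QueuedItem::Request.is_request());
    assert!(!QueuedItem::Response.is_request());
}
```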

QueuedItem::Request { channel, .. } | QueuedItem::Response { channel, .. } => {
self.process_wait_queue(*channel)?
}
QueuedItem::RequestCancellation { .. }
| QueuedItem::ResponseCancellation { .. }
| QueuedItem::Error { .. } => {}
}

// Check if the item is sendable immediately.
if let Some(channel) = item_should_wait(&item, &self.juliet, &self.active_multi_frame)? {
#[cfg(feature = "tracing")]
@@ -745,6 +757,7 @@
let id = msg.header().id();
self.request_map.insert(io_id, (channel, id));
if msg.is_multi_frame(self.juliet.max_frame_size()) {
debug_assert!(self.active_multi_frame[channel.get() as usize].is_none());
self.active_multi_frame[channel.get() as usize] = Some(msg.header());
}
self.ready_queue.push_back(msg.frames());
@@ -771,6 +784,7 @@
} => {
if let Some(msg) = self.juliet.create_response(channel, id, payload)? {
if msg.is_multi_frame(self.juliet.max_frame_size()) {
debug_assert!(self.active_multi_frame[channel.get() as usize].is_none());
self.active_multi_frame[channel.get() as usize] = Some(msg.header());
}
self.ready_queue.push_back(msg.frames())
@@ -835,11 +849,6 @@ where
self.wait_queue[channel.get() as usize].push_back(item);
} else {
self.send_to_ready_queue(item)?;

// No need to look further if we have saturated the channel.
Contributor:

Did 0253804#diff-76866598ce8fd16261a27ac58a84b2825e6e77fc37c163a6afa60f0f4477e569L852-L856 fix an issue? The code was supposed to bring down the potential $O(n^2)$ total complexity of processing the queue multiple times. What's the case that triggers this?

Contributor Author:

It wasn't an issue exposed via a test. Rather I thought it was a bug while following the logic during debugging.

The issue is that the wait queue can have not only requests but responses, so it would be wrong to exit early in the case where a bunch of responses could have been moved out of the wait queue.

As an aside, I wonder if it would be worthwhile creating a new enum just for the wait queue, similar to QueuedItem but with only Request and Response variants?
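The fix described above can be sketched in isolation (simplified items and hypothetical names, not the crate's actual code): when the request budget is exhausted, requests stay in the wait queue, but the scan must continue so queued responses still get moved out.

```rust
use std::collections::VecDeque;

#[derive(Debug, PartialEq)]
enum Item {
    Request,
    Response,
}

/// Move everything sendable out of the wait queue. Requests are held
/// back once the request budget is exhausted, but responses are not
/// limited here, so the loop must not break early on a blocked request.
fn process_wait_queue(wait: &mut VecDeque<Item>, mut request_budget: usize) -> Vec<Item> {
    let mut ready = Vec::new();
    // The range is computed once, so each original item is visited exactly once
    // even though blocked requests are pushed back onto the queue.
    for _ in 0..wait.len() {
        let item = wait.pop_front().expect("length checked above");
        match item {
            Item::Request if request_budget == 0 => wait.push_back(item), // keep waiting
            Item::Request => {
                request_budget -= 1;
                ready.push(item);
            }
            Item::Response => ready.push(item),
        }
    }
    ready
}

fn main() {
    let mut wait = VecDeque::from(vec![
        Item::Request,
        Item::Response,
        Item::Request,
        Item::Response,
    ]);
    let ready = process_wait_queue(&mut wait, 1);
    // One request and both responses move; the second request keeps waiting.
    assert_eq!(ready, vec![Item::Request, Item::Response, Item::Response]);
    assert_eq!(wait, VecDeque::from(vec![Item::Request]));
}
```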

Contributor:

As an aside, I wonder if it would be worthwhile creating a new enum just for the wait queue, similar to QueuedItem but with only Request and Response variants?

My guess is that the problem is likely best solved with two separate queues, one for requests and one for large messages (although I am not entirely sure yet how to handle the case where a message is both large and a request). Alternatively, we should keep some sort of state to ensure it can distinguish these cases quickly.

Contributor:

With the code being as-is in this PR, I'm still uncomfortable with the situation. Imagine queuing single-frame messages at a very high rate. Once we have saturated the ready queue, they will all go into the wait queue, and every call will process the now-growing entire queue.

Contributor:

We could consider adding this:

struct WaitSubQueue {
    single_frame: VecDeque<QueuedItem>,
    multi_frame: VecDeque<QueuedItem>,
}

struct WaitQueue {
    requests: WaitSubQueue,
    other: WaitSubQueue,
    prefer_request: bool,
}

impl WaitSubQueue {
    #[inline(always)]
    fn next_item(&mut self, allow_multi_frame: bool) -> Option<QueuedItem> {
        if allow_multi_frame && !self.multi_frame.is_empty() {
            self.multi_frame.pop_front()
        } else {
            self.single_frame.pop_front()
        }
    }
}

impl WaitQueue {
    pub fn next_item(
        &mut self,
        request_allowed: bool,
        multiframe_allowed: bool,
    ) -> Option<QueuedItem> {
        if request_allowed {
            self.next_item_allowing_request(multiframe_allowed)
        } else {
            self.other.next_item(multiframe_allowed)
        }
    }

    /// Returns the next item, assuming a request is allowed.
    // Note: This function is separated out for readability.
    #[inline(always)]
    fn next_item_allowing_request(&mut self, multiframe_allowed: bool) -> Option<QueuedItem> {
        let candidate = if self.prefer_request {
            self.requests
                .next_item(multiframe_allowed)
                .or_else(|| self.other.next_item(multiframe_allowed))
        } else {
            self.other
                .next_item(multiframe_allowed)
                .or_else(|| self.requests.next_item(multiframe_allowed))
        }?;

        // Alternate, to prevent starvation if the receiver is processing at a
        // rate that matches our production rate. This essentially subdivides
        // the channel into request/non-request subchannels.
        self.prefer_request = !candidate.is_request();
        Some(candidate)
    }
}

Since the logic gets more complex, it would be wise to separate it out. This is just a sketch, at least some comments would need to be filled in.

The key idea is to know what kind of item we can produce next by checking the state of our multiframe sends and request limits, then use the separated queue to optimize. This reorders items that weren't reordered before by separating the queues.
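The alternation itself can be demonstrated in isolation (a minimal sketch over bare values, not the crate's types): flipping the preference away from whichever kind was just produced interleaves the two queues, so neither side can starve the other.

```rust
use std::collections::VecDeque;

/// Drain two queues, flipping preference away from whichever kind was
/// just produced, mirroring the `prefer_request` flag above.
/// Each output entry is `(came_from_requests, value)`.
fn drain_alternating(
    mut requests: VecDeque<u32>,
    mut other: VecDeque<u32>,
) -> Vec<(bool, u32)> {
    let mut prefer_request = true;
    let mut out = Vec::new();
    while !requests.is_empty() || !other.is_empty() {
        // Pop from the preferred queue, falling back to the other one;
        // remember which kind was actually produced.
        let popped = if prefer_request {
            requests
                .pop_front()
                .map(|v| (true, v))
                .or_else(|| other.pop_front().map(|v| (false, v)))
        } else {
            other
                .pop_front()
                .map(|v| (false, v))
                .or_else(|| requests.pop_front().map(|v| (true, v)))
        };
        let (is_request, value) = popped.expect("one queue is non-empty");
        // Alternate: prefer the other kind next time.
        prefer_request = !is_request;
        out.push((is_request, value));
    }
    out
}

fn main() {
    let order = drain_alternating(VecDeque::from(vec![1, 2, 3]), VecDeque::from(vec![10, 20]));
    // Strict interleaving while both queues are non-empty.
    assert_eq!(
        order,
        vec![(true, 1), (false, 10), (true, 2), (false, 20), (true, 3)]
    );
}
```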

if !self.juliet.allowed_to_send_request(channel)? {
break;
}
}

// Ensure we do not loop endlessly if we cannot find anything.
@@ -867,6 +876,8 @@ fn item_should_wait<const N: usize>(
} => {
// Check if we cannot schedule due to the message exceeding the request limit.
if !juliet.allowed_to_send_request(*channel)? {
#[cfg(feature = "tracing")]
tracing::trace!(%channel, %item, "item should wait: channel full");
return Ok(Some(*channel));
}

@@ -889,6 +900,8 @@
if active_multi_frame.is_some() {
if let Some(payload) = payload {
if payload_is_multi_frame(juliet.max_frame_size(), payload.len()) {
#[cfg(feature = "tracing")]
tracing::trace!(%channel, %item, "item should wait: multiframe in progress");
return Ok(Some(*channel));
}
}
126 changes: 99 additions & 27 deletions src/protocol.rs
@@ -305,6 +305,11 @@ impl Channel {
self.outgoing_requests.len() < self.config.request_limit as usize
}

/// Returns the configured request limit for this channel.
pub fn request_limit(&self) -> u16 {
self.config.request_limit
}

/// Creates a new request, bypassing all client-side checks.
///
/// Low-level function that does nothing but create a syntactically correct request and track
@@ -474,7 +479,7 @@ impl Display for CompletedRead {
}
}

/// The caller of the this crate has violated the protocol.
/// The caller of this crate has violated the protocol.
///
/// A correct implementation of a client should never encounter this, thus simply unwrapping every
/// instance of this as part of a `Result<_, LocalProtocolViolation>` is usually a valid choice.
@@ -487,24 +492,45 @@ pub enum LocalProtocolViolation {
///
/// Wait for additional requests to be cancelled or answered. Calling
/// [`JulietProtocol::allowed_to_send_request()`] beforehand is recommended.
#[error("sending would exceed request limit")]
WouldExceedRequestLimit,
#[error("sending would exceed request limit of {limit}")]
WouldExceedRequestLimit {
/// The configured limit for requests on the channel.
limit: u16,
},
/// The channel given does not exist.
///
/// The given [`ChannelId`] exceeds `N` of [`JulietProtocol<N>`].
#[error("invalid channel")]
InvalidChannel(ChannelId),
#[error("channel {channel} not a member of configured {channel_count} channels")]
InvalidChannel {
/// The provided channel ID.
channel: ChannelId,
/// The configured number of channels.
channel_count: usize,
},
/// The given payload exceeds the configured limit.
///
/// See [`ChannelConfiguration::with_max_request_payload_size()`] and
/// [`ChannelConfiguration::with_max_response_payload_size()`] for details.
#[error("payload exceeds configured limit")]
PayloadExceedsLimit,
#[error("payload length of {payload_length} bytes exceeds configured limit of {limit}")]
PayloadExceedsLimit {
/// The payload length in bytes.
payload_length: usize,
/// The configured upper limit for payload length in bytes.
limit: usize,
},
/// The given error payload exceeds a single frame.
///
/// Error payloads may not span multiple frames, shorten the payload or increase frame size.
#[error("error payload would be multi-frame")]
ErrorPayloadIsMultiFrame,
#[error(
"error payload of {payload_length} bytes exceeds a single frame with configured max size \
of {max_frame_size}"
)]
ErrorPayloadIsMultiFrame {
/// The payload length in bytes.
payload_length: usize,
/// The configured maximum frame size in bytes.
max_frame_size: u32,
},
}

macro_rules! log_frame {
@@ -534,7 +560,10 @@ impl<const N: usize> JulietProtocol<N> {
#[inline(always)]
const fn lookup_channel(&self, channel: ChannelId) -> Result<&Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&self.channels[channel.0 as usize])
}
@@ -549,7 +578,10 @@
channel: ChannelId,
) -> Result<&mut Channel, LocalProtocolViolation> {
if channel.0 as usize >= N {
Err(LocalProtocolViolation::InvalidChannel(channel))
Err(LocalProtocolViolation::InvalidChannel {
channel,
channel_count: N,
})
} else {
Ok(&mut self.channels[channel.0 as usize])
}
@@ -595,12 +627,17 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_request_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_request_payload_size as usize,
});
}
}

if !chan.allowed_to_send_request() {
return Err(LocalProtocolViolation::WouldExceedRequestLimit);
return Err(LocalProtocolViolation::WouldExceedRequestLimit {
limit: chan.request_limit(),
});
}

Ok(chan.create_unchecked_request(channel, payload))
@@ -637,7 +674,10 @@ impl<const N: usize> JulietProtocol<N> {

if let Some(ref payload) = payload {
if payload.len() > chan.config.max_response_payload_size as usize {
return Err(LocalProtocolViolation::PayloadExceedsLimit);
return Err(LocalProtocolViolation::PayloadExceedsLimit {
payload_length: payload.len(),
limit: chan.config.max_response_payload_size as usize,
});
}
}

@@ -712,11 +752,15 @@ impl<const N: usize> JulietProtocol<N> {
id: Id,
payload: Bytes,
) -> Result<OutgoingMessage, LocalProtocolViolation> {
let header = Header::new_error(header::ErrorKind::Other, channel, id);
let header = Header::new_error(ErrorKind::Other, channel, id);

let payload_length = payload.len();
let msg = OutgoingMessage::new(header, Some(payload));
if msg.is_multi_frame(self.max_frame_size) {
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame)
Err(LocalProtocolViolation::ErrorPayloadIsMultiFrame {
payload_length,
max_frame_size: self.max_frame_size.0,
})
} else {
Ok(msg)
}
@@ -1253,7 +1297,8 @@ mod tests {

#[test]
fn test_channel_lookups_work() {
let mut protocol: JulietProtocol<3> = ProtocolBuilder::new().build();
const CHANNEL_COUNT: usize = 3;
let mut protocol: JulietProtocol<CHANNEL_COUNT> = ProtocolBuilder::new().build();

// We mark channels by inserting an ID into them, that way we can ensure we're not getting
// back the same channel every time.
@@ -1274,15 +1319,24 @@
.insert(Id::new(102));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel_mut(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));

// Now look up the channels and ensure they contain the right values
@@ -1309,15 +1363,24 @@
);
assert!(matches!(
protocol.lookup_channel(ChannelId(3)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(3)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(3),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(4)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(4)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(4),
channel_count: CHANNEL_COUNT
})
));
assert!(matches!(
protocol.lookup_channel(ChannelId(255)),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(255)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(255),
channel_count: CHANNEL_COUNT
})
));
}

@@ -1442,7 +1505,10 @@ mod tests {
// Try an invalid channel, should result in an error.
assert!(matches!(
protocol.create_request(ChannelId::new(2), payload.get()),
Err(LocalProtocolViolation::InvalidChannel(ChannelId(2)))
Err(LocalProtocolViolation::InvalidChannel {
channel: ChannelId(2),
channel_count: 2
})
));

assert!(protocol
@@ -1454,7 +1520,7 @@

assert!(matches!(
protocol.create_request(channel, payload.get()),
Err(LocalProtocolViolation::WouldExceedRequestLimit)
Err(LocalProtocolViolation::WouldExceedRequestLimit { limit: 1 })
));
}
}
@@ -2202,7 +2268,10 @@ mod tests {
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large request");

assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Bob must refuse it instead.
let bob_result = env.inject_and_send_request(Alice, payload.get());
@@ -2219,7 +2288,10 @@
.bob
.create_request(env.common_channel, payload.get())
.expect_err("should not be able to create too large response");
assert_matches!(violation, LocalProtocolViolation::PayloadExceedsLimit);
assert_matches!(
violation,
LocalProtocolViolation::PayloadExceedsLimit { .. }
);

// If we force the issue, Alice must refuse it.
let alice_result = env.inject_and_send_response(Bob, id, payload.get());
1 change: 1 addition & 0 deletions src/protocol/multiframe.rs
@@ -44,6 +44,7 @@ pub(super) enum MultiframeReceiver {

/// The outcome of a multiframe acceptance.
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub(crate) enum CompletedFrame {
/// A new multi-frame transfer was started.
NewMultiFrame,