Skip to content

Commit

Permalink
refactor: Shared to use internal mutability
Browse files Browse the repository at this point in the history
  • Loading branch information
akaladarshi committed Jan 26, 2025
1 parent 3f9e8c3 commit 6e031d6
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 136 deletions.
152 changes: 80 additions & 72 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ struct Active<T> {
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
streams: IntMap<StreamId, stream::Shared>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

Expand Down Expand Up @@ -507,9 +507,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
let s = self.streams.remove(&stream_id).expect("stream not found");

log::trace!("{}: removing dropped stream {}", self.id, stream_id);
let frame = {
let mut shared = s.lock();
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
let frame = s.with_mut(|inner| {
let frame = match inner.update_state(self.id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open { .. } => {
Expand Down Expand Up @@ -541,14 +540,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
if let Some(w) = inner.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
if let Some(w) = inner.writer.take() {
w.wake()
}

frame
};
});
frame.map(Into::into)
}

Expand All @@ -565,10 +565,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
{
let id = frame.header().stream_id();
if let Some(stream) = self.streams.get(&id) {
stream
.lock()
.update_state(self.id, id, State::Open { acknowledged: true });
if let Some(shared) = self.streams.get(&id) {
shared.update_state(self.id, id, State::Open { acknowledged: true });
}
if let Some(waker) = self.new_outbound_stream_waker.take() {
waker.wake();
Expand All @@ -590,14 +588,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
s.with_mut(|inner| {
inner.update_state(self.id, stream_id, State::Closed);
if let Some(w) = inner.reader.take() {
w.wake()
}
if let Some(w) = inner.writer.take() {
w.wake()
}
});
}
return Action::None;
}
Expand Down Expand Up @@ -628,35 +627,40 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
{
let mut shared = stream.shared();
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
shared.consume_receive_window(frame.body_len());
shared.buffer.push(frame.into_body());
stream.shared().with_mut(|inner| {
if is_finish {
inner.update_state(self.id, stream_id, State::RecvClosed);
}
inner.consume_receive_window(frame.body_len());
inner.buffer.push(frame.into_body());
})
}
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream);
}

if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
if frame.body_len() > shared.receive_window() {
log::error!(
"{}/{}: frame body larger than window of stream",
self.id,
stream_id
);
return Action::Terminate(Frame::protocol_error());
}
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
shared.consume_receive_window(frame.body_len());
shared.buffer.push(frame.into_body());
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(shared) = self.streams.get_mut(&stream_id) {
let action = shared.with_mut(|inner| {
if frame.body_len() > inner.receive_window() {
log::error!(
"{}/{}: frame body larger than window of stream",
self.id,
stream_id
);
Action::Terminate(Frame::protocol_error())
} else {
if is_finish {
inner.update_state(self.id, stream_id, State::RecvClosed);
}
inner.consume_receive_window(frame.body_len());
inner.buffer.push(frame.into_body());
if let Some(w) = inner.reader.take() {
w.wake()
}
Action::None
}
});
return action;
} else {
log::trace!(
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
Expand All @@ -681,15 +685,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {

if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
if let Some(shared) = self.streams.get_mut(&stream_id) {
shared.with_mut(|inner| {
inner.update_state(self.id, stream_id, State::Closed);
if let Some(w) = inner.reader.take() {
w.wake()
}
if let Some(w) = inner.writer.take() {
w.wake()
}
});
}
return Action::None;
}
Expand Down Expand Up @@ -723,19 +728,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
return Action::New(stream);
}

if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.increase_send_window_by(frame.header().credit());
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
if let Some(shared) = self.streams.get_mut(&stream_id) {
shared.with_mut(|inner| {
inner.increase_send_window_by(frame.header().credit());
if is_finish {
inner.update_state(self.id, stream_id, State::RecvClosed);

if let Some(w) = inner.reader.take() {
w.wake()
}
}

if let Some(w) = shared.reader.take() {
if let Some(w) = inner.writer.take() {
w.wake()
}
}
if let Some(w) = shared.writer.take() {
w.wake()
}
});
} else {
log::trace!(
"{}/{}: window update for unknown stream, possibly dropped earlier: {:?}",
Expand Down Expand Up @@ -848,7 +855,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Client => id.is_client(),
Mode::Server => id.is_server(),
})
.filter(|(_, s)| s.lock().is_pending_ack())
.filter(|(_, s)| s.is_pending_ack())
.count()
}

Expand All @@ -867,15 +874,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
impl<T> Active<T> {
/// Close and drop all `Stream`s and wake any pending `Waker`s.
fn drop_all_streams(&mut self) {
for (id, s) in self.streams.drain() {
let mut shared = s.lock();
shared.update_state(self.id, id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
for (id, shared) in self.streams.drain() {
shared.with_mut(|inner| {
inner.update_state(self.id, id, State::Closed);
if let Some(w) = inner.reader.take() {
w.wake()
}
if let Some(w) = inner.writer.take() {
w.wake()
}
});
}
}
}
Loading

0 comments on commit 6e031d6

Please sign in to comment.