diff --git a/ext/websocket/01_websocket.js b/ext/websocket/01_websocket.js index 78572f5f002fab..737643c0742c1f 100644 --- a/ext/websocket/01_websocket.js +++ b/ext/websocket/01_websocket.js @@ -429,6 +429,7 @@ class WebSocket extends EventTarget { const rid = this[_rid]; while (this[_readyState] !== CLOSED) { const kind = await op_ws_next_event(rid); + /* close the connection if read was cancelled, and we didn't get a close frame */ if ( (this[_readyState] == CLOSING) && @@ -442,6 +443,10 @@ class WebSocket extends EventTarget { break; } + if (kind == null) { + break; + } + switch (kind) { case 0: { /* string */ diff --git a/ext/websocket/lib.rs b/ext/websocket/lib.rs index 5aef1a7a550646..d27fcb59a1215f 100644 --- a/ext/websocket/lib.rs +++ b/ext/websocket/lib.rs @@ -8,6 +8,7 @@ use deno_core::url; use deno_core::AsyncMutFuture; use deno_core::AsyncRefCell; use deno_core::ByteString; +use deno_core::CancelFuture; use deno_core::CancelHandle; use deno_core::CancelTryFuture; use deno_core::JsBuffer; @@ -532,6 +533,7 @@ pub struct ServerWebSocket { string: Cell>, ws_read: AsyncRefCell>>, ws_write: AsyncRefCell>>, + cancel_handle: Rc, } impl ServerWebSocket { @@ -546,6 +548,7 @@ impl ServerWebSocket { string: Cell::new(None), ws_read: AsyncRefCell::new(FragmentCollectorRead::new(ws_read)), ws_write: AsyncRefCell::new(ws_write), + cancel_handle: CancelHandle::new_rc(), } } @@ -752,7 +755,7 @@ pub async fn op_ws_close( let Ok(resource) = state .borrow_mut() .resource_table - .get::(rid) + .take::(rid) else { return Ok(()); }; @@ -762,6 +765,8 @@ pub async fn op_ws_close( .unwrap_or_else(|| Frame::close_raw(vec![].into())); resource.closed.set(true); + + resource.cancel_handle.cancel(); let lock = resource.reserve_lock(); resource.write_frame(lock, frame).await } @@ -804,19 +809,19 @@ pub fn op_ws_get_error(state: &mut OpState, #[smi] rid: ResourceId) -> String { pub async fn op_ws_next_event( state: Rc>, #[smi] rid: ResourceId, -) -> u16 { +) -> Option { let Ok(resource) = state .borrow_mut() .resource_table .get::(rid) else { // op_ws_get_error will correctly handle a bad resource - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); }; // If there's a pending error, this always returns error if resource.errored.get() { - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); } let mut ws = RcRef::map(&resource, |r| &r.ws_read).borrow_mut().await; @@ -825,19 +830,26 @@ pub async fn op_ws_next_event( let writer = writer.clone(); async move { writer.borrow_mut().await.write_frame(frame).await } }; + let cancel_handle = resource.cancel_handle.clone(); loop { - let res = ws.read_frame(&mut sender).await; + let Ok(res) = ws + .read_frame(&mut sender) + .or_cancel(cancel_handle.clone()) + .await + else { + return None; + }; let val = match res { Ok(val) => val, Err(err) => { // No message was received, socket closed while we waited. // Report closed status to JavaScript. if resource.closed.get() { - return MessageKind::ClosedDefault as u16; + return Some(MessageKind::ClosedDefault as u16); } resource.set_error(Some(err.to_string())); - return MessageKind::Error as u16; + return Some(MessageKind::Error as u16); } }; @@ -845,22 +857,22 @@ pub async fn op_ws_next_event( OpCode::Text => match String::from_utf8(val.payload.to_vec()) { Ok(s) => { resource.string.set(Some(s)); - MessageKind::Text as u16 + Some(MessageKind::Text as u16) } Err(_) => { resource.set_error(Some("Invalid string data".into())); - MessageKind::Error as u16 + Some(MessageKind::Error as u16) } }, OpCode::Binary => { resource.buffer.set(Some(val.payload.to_vec())); - MessageKind::Binary as u16 + Some(MessageKind::Binary as u16) } OpCode::Close => { // Close reason is returned through error if val.payload.len() < 2 { resource.set_error(None); - MessageKind::ClosedDefault as u16 + Some(MessageKind::ClosedDefault as u16) } else { let close_code = CloseCode::from(u16::from_be_bytes([ val.payload[0], @@ -868,10 +880,10 @@ pub async fn op_ws_next_event( ])); let reason = String::from_utf8(val.payload[2..].to_vec()).ok(); resource.set_error(reason); - close_code.into() + Some(close_code.into()) } } - OpCode::Pong => MessageKind::Pong as u16, + OpCode::Pong => Some(MessageKind::Pong as u16), OpCode::Continuation | OpCode::Ping => { continue; } diff --git a/tests/unit/websocket_test.ts b/tests/unit/websocket_test.ts index d9878828db4d97..3dad8fa8bd4b67 100644 --- a/tests/unit/websocket_test.ts +++ b/tests/unit/websocket_test.ts @@ -821,3 +821,12 @@ Deno.test("send to a closed socket", async () => { }; await promise; }); + +Deno.test(async function websocketDoesntLeak() { + const { promise, resolve } = Promise.withResolvers(); + const ws = new WebSocket(new URL("ws://localhost:4242/")); + assertEquals(ws.url, "ws://localhost:4242/"); + ws.onopen = () => resolve(); + await promise; + ws.close(); +});