diff --git a/Cargo.toml b/Cargo.toml index 46c5069e..84812d9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,8 @@ resolver = "2" bytes = { version = "1.7", features = ["serde"] } futures-core = "0.3" futures-util = { version = "0.3", default-features = false, features = ["std"] } -tokio = "1.40" -tokio-tungstenite = "0.24" +tokio = "1.43" +tokio-tungstenite = "0.26" serde = { version = "1.0", features = ["derive"] } smallvec = { version = "1.13", features = ["union"] } serde_json = "1.0" diff --git a/crates/engineioxide/src/transport/ws.rs b/crates/engineioxide/src/transport/ws.rs index dbfbf202..b1ed5d91 100644 --- a/crates/engineioxide/src/transport/ws.rs +++ b/crates/engineioxide/src/transport/ws.rs @@ -16,7 +16,7 @@ use tokio::{ task::JoinHandle, }; use tokio_tungstenite::{ - tungstenite::{handshake::derive_accept_key, protocol::Role, Message}, + tungstenite::{handshake::derive_accept_key, protocol::Role, Message, Utf8Bytes}, WebSocketStream, }; @@ -27,10 +27,9 @@ use crate::{ errors::Error, handler::EngineIoHandler, packet::{OpenPacket, Packet}, - service::ProtocolVersion, - service::TransportType, + service::{ProtocolVersion, TransportType}, sid::Sid, - DisconnectReason, Socket, + DisconnectReason, Socket, Str, }; /// Create a response for websocket upgrade @@ -165,29 +164,32 @@ where { while let Some(msg) = rx.try_next().await? { match msg { - Message::Text(msg) => match Packet::try_from(msg)? { - Packet::Close => { - #[cfg(feature = "tracing")] - tracing::debug!("[sid={}] closing session", socket.id); - engine.close_session(socket.id, DisconnectReason::TransportClose); - break; - } - Packet::Pong | Packet::Ping => socket - .heartbeat_tx - .try_send(()) - .map_err(|_| Error::HeartbeatTimeout), - Packet::Message(msg) => { - engine.handler.on_message(msg, socket.clone()); - Ok(()) + Message::Text(msg) => { + match Packet::try_from(unsafe { Str::from_bytes_unchecked(msg.into()) })? { + Packet::Close => { + #[cfg(feature = "tracing")] + tracing::debug!("[sid={}] closing session", socket.id); + engine.close_session(socket.id, DisconnectReason::TransportClose); + break; + } + Packet::Pong | Packet::Ping => socket + .heartbeat_tx + .try_send(()) + .map_err(|_| Error::HeartbeatTimeout), + Packet::Message(msg) => { + engine.handler.on_message(msg, socket.clone()); + Ok(()) + } + p => return Err(Error::BadPacket(p)), } - p => return Err(Error::BadPacket(p)), - }, + } Message::Binary(mut data) => { + #[cfg(feature = "v3")] if socket.protocol == ProtocolVersion::V3 && !data.is_empty() { // The first byte is the message type, which we don't need. - let _ = data.remove(0); + data = data.slice(1..); } - engine.handler.on_binary(data.into(), socket.clone()); + engine.handler.on_binary(data, socket.clone()); Ok(()) } Message::Close(_) => break, @@ -220,12 +222,12 @@ where macro_rules! map_fn { ($item:ident) => { let res = match $item { - Packet::Binary(bin) | Packet::BinaryV3(bin) => { - let mut bin: Vec = bin.into(); - if socket.protocol == ProtocolVersion::V3 { - // v3 protocol requires packet type as the first byte - bin.insert(0, 0x04); - } + Packet::Binary(bin) => tx.feed(Message::Binary(bin)).await, + Packet::BinaryV3(bin) => { + // v3 protocol requires packet type as the first byte + let mut buf = Vec::with_capacity(bin.len() + 1); + buf.push(0x04); + buf.extend_from_slice(&bin); tx.feed(Message::Binary(bin)).await } Packet::Close => { @@ -239,7 +241,7 @@ where Packet::Noop => Ok(()), _ => { let packet: String = $item.try_into().unwrap(); - tx.feed(Message::Text(packet)).await + tx.feed(Message::Text(Utf8Bytes::from(packet))).await } }; if let Err(_e) = res { @@ -274,7 +276,8 @@ where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config)); - ws.send(Message::Text(packet.try_into()?)).await?; + let value: String = packet.try_into()?; + ws.send(Message::Text(Utf8Bytes::from(value))).await?; Ok(()) } @@ -317,11 +320,11 @@ where Some(Ok(Message::Text(d))) => d, _ => Err(Error::Upgrade)?, }; - match Packet::try_from(msg)? { + match Packet::try_from(unsafe { Str::from_bytes_unchecked(msg.into()) })? { Packet::PingUpgrade => { // Respond with a PongUpgrade packet - ws.send(Message::Text(Packet::PongUpgrade.try_into()?)) - .await?; + let msg: String = Packet::PongUpgrade.try_into()?; + ws.send(Message::Text(Utf8Bytes::from(msg))).await?; } p => Err(Error::BadPacket(p))?, }; @@ -343,7 +346,7 @@ where Err(Error::Upgrade)? } }; - match Packet::try_from(msg)? { + match Packet::try_from(unsafe { Str::from_bytes_unchecked(msg.into()) })? { Packet::Upgrade => { #[cfg(feature = "tracing")] tracing::debug!("ws upgraded successful")