diff --git a/clash/tests/data/config/wg.yaml b/clash/tests/data/config/wg.yaml index 4cc12a0d3..3e93f3399 100644 --- a/clash/tests/data/config/wg.yaml +++ b/clash/tests/data/config/wg.yaml @@ -33,7 +33,6 @@ dns: - 8.8.8.8 # default value - tls://dns.google:853 # DNS over TLS - https://1.1.1.1/dns-query # DNS over HTTPS - - dhcp://en0 # dns from dhcp allow-lan: true mode: rule diff --git a/clash_lib/src/proxy/wg/device.rs b/clash_lib/src/proxy/wg/device.rs index 4e6f73ddb..4f81aa2a3 100644 --- a/clash_lib/src/proxy/wg/device.rs +++ b/clash_lib/src/proxy/wg/device.rs @@ -55,10 +55,7 @@ impl DeviceManager { .lock() .await .push((socket, remote, read_pair.0, write_pair.1)); - SocketPair { - read: read_pair.1, - write: write_pair.0, - } + SocketPair::new(read_pair.1, write_pair.0) } pub async fn poll_sockets(&self, mut device: VirtualIpDevice) { @@ -71,13 +68,12 @@ impl DeviceManager { }); loop { + let mut need_wait = true; + let timestamp = Instant::now(); let mut sockets = self.socket_set.lock().await; let mut socket_pairs = self.socket_pairs.lock().await; - trace!("polling active socket: {}", socket_pairs.len()); - iface.poll(timestamp, &mut device, &mut sockets); - let mut unconnected_sockets = self.unconnected_sockets.lock().await; let unconnected_sockets = unconnected_sockets.drain(0..); for (mut socket, remote, sender, receiver) in unconnected_sockets { @@ -91,20 +87,23 @@ impl DeviceManager { let handle = sockets.add(socket); socket_pairs.insert(handle, (sender, receiver)); + + need_wait = false; } + iface.poll(timestamp, &mut device, &mut sockets); + for (handle, (sender, receiver)) in socket_pairs.iter_mut() { let socket = sockets.get_mut::(*handle); if socket.may_recv() { - match socket.recv(|data| (data.len(), data)) { - Ok(data) => match sender.try_send(data.to_owned().into()) { - Ok(_) => { - trace!("data forwared to socket: {:?}", socket); - } + match socket.recv(|data| (data.len(), data.to_owned())) { + Ok(data) if !data.is_empty() => match sender.try_send(data.into()) { + Ok(_) => {} Err(e) => { warn!("failed to send tcp packet: {:?}", e); } }, + Ok(_) => {} Err(RecvError::Finished) => { warn!("tcp socket finished"); continue; @@ -113,27 +112,35 @@ impl DeviceManager { warn!("failed to receive tcp packet: {:?}", e); } } + need_wait = false; } if socket.may_send() { match receiver.try_recv() { Ok(data) => match socket.send_slice(&data) { - Ok(_) => {} + Ok(n) => { + trace!("sent {} bytes, total: {}", n, data.len()); + if n != data.len() { + error!("fix me"); + } + } Err(e) => { warn!("failed to send tcp packet: {:?}", e); } }, - Err(e) => { - warn!("failed to receive tcp packet: {:?}", e); - } + Err(_) => {} } + need_wait = false; } } - match iface.poll_delay(timestamp, &sockets) { - Some(delay) => { - tokio::time::sleep(delay.into()).await; + if need_wait { + match iface.poll_delay(timestamp, &sockets) { + Some(delay) => { + trace!("device poll delay: {:?}", delay); + tokio::time::sleep(delay.into()).await; + } + None => {} } - None => {} } let mut port_to_release = Vec::new(); @@ -201,11 +208,9 @@ impl Device for VirtualIpDevice { &mut self, timestamp: smoltcp::time::Instant, ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - trace!("wg tun polling data"); let next = self.packet_receiver.try_recv().ok(); match next { Some((proto, data)) => { - trace!("wg tun received packet"); let rx_token = RxToken { buffer: { let mut buffer = BytesMut::new(); @@ -223,7 +228,6 @@ impl Device for VirtualIpDevice { } fn transmit(&mut self, timestamp: smoltcp::time::Instant) -> Option> { - trace!("wg tun writing data"); Some(TxToken { sender: self.packet_sender.clone(), }) diff --git a/clash_lib/src/proxy/wg/stack/tcp.rs b/clash_lib/src/proxy/wg/stack/tcp.rs index cdf4f4e3a..ef7df72b5 100644 --- a/clash_lib/src/proxy/wg/stack/tcp.rs +++ b/clash_lib/src/proxy/wg/stack/tcp.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use tokio::{ io::{AsyncRead, AsyncWrite}, @@ -12,24 +12,52 @@ use tracing::trace; pub struct SocketPair { pub read: Receiver, pub write: Sender, + + read_buf: BytesMut, +} + +impl SocketPair { + pub fn new(read: Receiver, write: Sender) -> Self { + Self { + read, + write, + read_buf: BytesMut::new(), + } + } } impl AsyncRead for SocketPair { fn poll_read( mut self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, + cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - match self.read.try_recv() { - Ok(data) => { - trace!("tcp socket received: {:?}", data); - buf.put_slice(&data); + if !self.read_buf.is_empty() { + let len = std::cmp::min(self.read_buf.len(), buf.remaining()); + buf.put_slice(&self.read_buf.split_to(len)); + trace!( + "reusing cached data sent {}, left {}", + len, + self.read_buf.len() + ); + return std::task::Poll::Ready(Ok(())); + } + + match self.read.poll_recv(cx) { + std::task::Poll::Ready(Some(data)) => { + let len = std::cmp::min(data.len(), buf.remaining()); + buf.put_slice(&data[..len]); + self.read_buf.extend_from_slice(&data[len..]); + trace!( + "socket got {} data, sent {}, left {}", + data.len(), + len, + self.read_buf.len() + ); std::task::Poll::Ready(Ok(())) } - Err(_) => { - trace!("no data ready"); - std::task::Poll::Pending - } + std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(())), + std::task::Poll::Pending => std::task::Poll::Pending, } } } diff --git a/clash_lib/src/proxy/wg/wireguard.rs b/clash_lib/src/proxy/wg/wireguard.rs index 41189c7ba..d20913316 100644 --- a/clash_lib/src/proxy/wg/wireguard.rs +++ b/clash_lib/src/proxy/wg/wireguard.rs @@ -115,22 +115,14 @@ impl WireguardTunnel { let mut packet_reader = self.packet_reader.lock().await; loop { - let timeouted_recv = - tokio::time::timeout(Duration::from_secs(30), packet_reader.recv()); - - match timeouted_recv.await { - Ok(Some(packet)) => { + match packet_reader.recv().await { + Some(packet) => { if let Err(e) = self.send_ip_packet(&packet).await { error!("failed to send packet: {}", e); } } - Ok(None) => { - trace!("connection closed, stopping"); - break; - } - Err(e) => { - trace!("no active connection, stopping: {e:?}"); - tokio::time::sleep(Duration::from_millis(1)).await; + None => { + trace!("no active connection, stopping"); break; } } @@ -148,24 +140,14 @@ impl WireguardTunnel { pub async fn start_receiving(&self) { loop { - trace!("wg stack receiving data"); - let mut recv_buf = [0u8; 65535]; let mut send_buf = [0u8; 65535]; - let timeouted_recv = - tokio::time::timeout(Duration::from_secs(30), self.udp.recv(&mut recv_buf)); - - let size = match timeouted_recv.await { - Ok(Ok(size)) => size, - Ok(Err(e)) => { - error!("failed to receive packet: {e:?}"); - tokio::time::sleep(Duration::from_millis(1)).await; - continue; - } + let size = match self.udp.recv(&mut recv_buf).await { + Ok(size) => size, Err(e) => { - trace!("no active connection, stopping: {e:?}"); - break; + error!("failed to receive packet: {}", e); + continue; } }; let mut peer = self.peer.lock().await; @@ -256,22 +238,28 @@ impl WireguardTunnel { /// Determine the inner protocol of the incoming IP packet (TCP/UDP). fn route_protocol(&self, packet: &[u8]) -> Option { match IpVersion::of_packet(packet) { - Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet).ok().and_then(|packet| { - match packet.next_header() { - IpProtocol::Tcp => Some(PortProtocol::Tcp), - IpProtocol::Udp => Some(PortProtocol::Udp), - // Unrecognized protocol, so we cannot determine where to route - _ => None, - } - }), - Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet).ok().and_then(|packet| { - match packet.next_header() { - IpProtocol::Tcp => Some(PortProtocol::Tcp), - IpProtocol::Udp => Some(PortProtocol::Udp), - // Unrecognized protocol, so we cannot determine where to route - _ => None, - } - }), + Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) + .ok() + .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) + .and_then(|packet| { + match packet.next_header() { + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), + // Unrecognized protocol, so we cannot determine where to route + _ => None, + } + }), + Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) + .ok() + .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) + .and_then(|packet| { + match packet.next_header() { + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), + // Unrecognized protocol, so we cannot determine where to route + _ => None, + } + }), _ => None, } }