Skip to content

Commit

Permalink
TCP working
Browse files Browse the repository at this point in the history
  • Loading branch information
ibigbug committed Dec 18, 2023
1 parent beb3078 commit f56ed43
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 76 deletions.
1 change: 0 additions & 1 deletion clash/tests/data/config/wg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ dns:
- 8.8.8.8 # default value
- tls://dns.google:853 # DNS over TLS
- https://1.1.1.1/dns-query # DNS over HTTPS
- dhcp://en0 # dns from dhcp

allow-lan: true
mode: rule
Expand Down
50 changes: 27 additions & 23 deletions clash_lib/src/proxy/wg/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ impl DeviceManager {
.lock()
.await
.push((socket, remote, read_pair.0, write_pair.1));
SocketPair {
read: read_pair.1,
write: write_pair.0,
}
SocketPair::new(read_pair.1, write_pair.0)
}

pub async fn poll_sockets(&self, mut device: VirtualIpDevice) {
Expand All @@ -71,13 +68,12 @@ impl DeviceManager {
});

loop {
let mut need_wait = true;

let timestamp = Instant::now();
let mut sockets = self.socket_set.lock().await;
let mut socket_pairs = self.socket_pairs.lock().await;

trace!("polling active socket: {}", socket_pairs.len());
iface.poll(timestamp, &mut device, &mut sockets);

let mut unconnected_sockets = self.unconnected_sockets.lock().await;
let unconnected_sockets = unconnected_sockets.drain(0..);
for (mut socket, remote, sender, receiver) in unconnected_sockets {
Expand All @@ -91,20 +87,23 @@ impl DeviceManager {

let handle = sockets.add(socket);
socket_pairs.insert(handle, (sender, receiver));

need_wait = false;
}

iface.poll(timestamp, &mut device, &mut sockets);

for (handle, (sender, receiver)) in socket_pairs.iter_mut() {
let socket = sockets.get_mut::<tcp::Socket>(*handle);
if socket.may_recv() {
match socket.recv(|data| (data.len(), data)) {
Ok(data) => match sender.try_send(data.to_owned().into()) {
Ok(_) => {
trace!("data forwared to socket: {:?}", socket);
}
match socket.recv(|data| (data.len(), data.to_owned())) {
Ok(data) if !data.is_empty() => match sender.try_send(data.into()) {
Ok(_) => {}
Err(e) => {
warn!("failed to send tcp packet: {:?}", e);
}
},
Ok(_) => {}
Err(RecvError::Finished) => {
warn!("tcp socket finished");
continue;
Expand All @@ -113,27 +112,35 @@ impl DeviceManager {
warn!("failed to receive tcp packet: {:?}", e);
}
}
need_wait = false;
}
if socket.may_send() {
match receiver.try_recv() {
Ok(data) => match socket.send_slice(&data) {
Ok(_) => {}
Ok(n) => {
trace!("sent {} bytes, total: {}", n, data.len());
if n != data.len() {
error!("fix me");
}
}
Err(e) => {
warn!("failed to send tcp packet: {:?}", e);
}
},
Err(e) => {
warn!("failed to receive tcp packet: {:?}", e);
}
Err(_) => {}
}
need_wait = false;
}
}

match iface.poll_delay(timestamp, &sockets) {
Some(delay) => {
tokio::time::sleep(delay.into()).await;
if need_wait {
match iface.poll_delay(timestamp, &sockets) {
Some(delay) => {
trace!("device poll delay: {:?}", delay);
tokio::time::sleep(delay.into()).await;
}
None => {}
}
None => {}
}

let mut port_to_release = Vec::new();
Expand Down Expand Up @@ -201,11 +208,9 @@ impl Device for VirtualIpDevice {
&mut self,
timestamp: smoltcp::time::Instant,
) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
trace!("wg tun polling data");
let next = self.packet_receiver.try_recv().ok();
match next {
Some((proto, data)) => {
trace!("wg tun received packet");
let rx_token = RxToken {
buffer: {
let mut buffer = BytesMut::new();
Expand All @@ -223,7 +228,6 @@ impl Device for VirtualIpDevice {
}

fn transmit(&mut self, timestamp: smoltcp::time::Instant) -> Option<Self::TxToken<'_>> {
trace!("wg tun writing data");
Some(TxToken {
sender: self.packet_sender.clone(),
})
Expand Down
48 changes: 38 additions & 10 deletions clash_lib/src/proxy/wg/stack/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Debug;

use bytes::Bytes;
use bytes::{Bytes, BytesMut};

use tokio::{
io::{AsyncRead, AsyncWrite},
Expand All @@ -12,24 +12,52 @@ use tracing::trace;
pub struct SocketPair {
pub read: Receiver<Bytes>,
pub write: Sender<Bytes>,

read_buf: BytesMut,
}

impl SocketPair {
pub fn new(read: Receiver<Bytes>, write: Sender<Bytes>) -> Self {
Self {
read,
write,
read_buf: BytesMut::new(),
}
}
}

impl AsyncRead for SocketPair {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.read.try_recv() {
Ok(data) => {
trace!("tcp socket received: {:?}", data);
buf.put_slice(&data);
if !self.read_buf.is_empty() {
let len = std::cmp::min(self.read_buf.len(), buf.remaining());
buf.put_slice(&self.read_buf.split_to(len));
trace!(
"reusing cached data sent {}, left {}",
len,
self.read_buf.len()
);
return std::task::Poll::Ready(Ok(()));
}

match self.read.poll_recv(cx) {
std::task::Poll::Ready(Some(data)) => {
let len = std::cmp::min(data.len(), buf.remaining());
buf.put_slice(&data[..len]);
self.read_buf.extend_from_slice(&data[len..]);
trace!(
"socket got {} data, sent {}, left {}",
data.len(),
len,
self.read_buf.len()
);
std::task::Poll::Ready(Ok(()))
}
Err(_) => {
trace!("no data ready");
std::task::Poll::Pending
}
std::task::Poll::Ready(None) => std::task::Poll::Ready(Ok(())),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
Expand Down
72 changes: 30 additions & 42 deletions clash_lib/src/proxy/wg/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,14 @@ impl WireguardTunnel {

let mut packet_reader = self.packet_reader.lock().await;
loop {
let timeouted_recv =
tokio::time::timeout(Duration::from_secs(30), packet_reader.recv());

match timeouted_recv.await {
Ok(Some(packet)) => {
match packet_reader.recv().await {
Some(packet) => {
if let Err(e) = self.send_ip_packet(&packet).await {
error!("failed to send packet: {}", e);
}
}
Ok(None) => {
trace!("connection closed, stopping");
break;
}
Err(e) => {
trace!("no active connection, stopping: {e:?}");
tokio::time::sleep(Duration::from_millis(1)).await;
None => {
trace!("no active connection, stopping");
break;
}
}
Expand All @@ -148,24 +140,14 @@ impl WireguardTunnel {

pub async fn start_receiving(&self) {
loop {
trace!("wg stack receiving data");

let mut recv_buf = [0u8; 65535];
let mut send_buf = [0u8; 65535];

let timeouted_recv =
tokio::time::timeout(Duration::from_secs(30), self.udp.recv(&mut recv_buf));

let size = match timeouted_recv.await {
Ok(Ok(size)) => size,
Ok(Err(e)) => {
error!("failed to receive packet: {e:?}");
tokio::time::sleep(Duration::from_millis(1)).await;
continue;
}
let size = match self.udp.recv(&mut recv_buf).await {
Ok(size) => size,
Err(e) => {
trace!("no active connection, stopping: {e:?}");
break;
error!("failed to receive packet: {}", e);
continue;
}
};
let mut peer = self.peer.lock().await;
Expand Down Expand Up @@ -256,22 +238,28 @@ impl WireguardTunnel {
/// Determine the inner protocol of the incoming IP packet (TCP/UDP).
fn route_protocol(&self, packet: &[u8]) -> Option<PortProtocol> {
match IpVersion::of_packet(packet) {
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet).ok().and_then(|packet| {
match packet.next_header() {
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => None,
}
}),
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet).ok().and_then(|packet| {
match packet.next_header() {
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => None,
}
}),
Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet)
.ok()
.filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip)
.and_then(|packet| {
match packet.next_header() {
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => None,
}
}),
Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet)
.ok()
.filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip)
.and_then(|packet| {
match packet.next_header() {
IpProtocol::Tcp => Some(PortProtocol::Tcp),
IpProtocol::Udp => Some(PortProtocol::Udp),
// Unrecognized protocol, so we cannot determine where to route
_ => None,
}
}),
_ => None,
}
}
Expand Down

0 comments on commit f56ed43

Please sign in to comment.