From 73cdd5232362fa1949a6fc8fdd7549aa59cf9ef3 Mon Sep 17 00:00:00 2001 From: dev0 Date: Mon, 11 Dec 2023 02:53:31 +1100 Subject: [PATCH] WIP --- Cargo.lock | 121 ++++++++++++++ clash_lib/Cargo.toml | 2 + clash_lib/src/proxy/mod.rs | 2 +- clash_lib/src/proxy/wg/device.rs | 14 ++ clash_lib/src/proxy/wg/events.rs | 108 +++++++++++++ clash_lib/src/proxy/wg/mod.rs | 22 +-- clash_lib/src/proxy/wg/stack/mod.rs | 65 ++++++++ clash_lib/src/proxy/wg/stack/tcp.rs | 42 +++++ clash_lib/src/proxy/wg/stack/udp.rs | 0 clash_lib/src/proxy/wg/wireguard.rs | 236 ++++++++++++++++++++++++++++ 10 files changed, 601 insertions(+), 11 deletions(-) create mode 100644 clash_lib/src/proxy/wg/device.rs create mode 100644 clash_lib/src/proxy/wg/events.rs create mode 100644 clash_lib/src/proxy/wg/stack/mod.rs create mode 100644 clash_lib/src/proxy/wg/stack/tcp.rs create mode 100644 clash_lib/src/proxy/wg/stack/udp.rs create mode 100644 clash_lib/src/proxy/wg/wireguard.rs diff --git a/Cargo.lock b/Cargo.lock index 1756afaaf..c75de6770 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,6 +228,15 @@ dependencies = [ "syn 2.0.37", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "atty" version = "0.2.14" @@ -820,6 +829,7 @@ dependencies = [ "serde_yaml", "sha2", "shadowsocks", + "smoltcp", "socket2 0.5.5", "state", "tempfile", @@ -980,6 +990,12 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "critical-section" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216" + [[package]] name = "crossbeam" version = "0.8.2" @@ -1122,6 +1138,38 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "defmt" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a2d011b2fee29fb7d659b83c43fce9a2cb4df453e16d441a51448e448f3f98" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54f0216f6c5acb5ae1a47050a6645024e6edafc2ee32d421955eccfef12ef92e" +dependencies = [ + "defmt-parser", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "defmt-parser" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "269924c02afd7f94bc4cecbfa5c379f6ffcf9766b3408fe63d22c728654eccd0" +dependencies = [ + "thiserror", +] + [[package]] name = "der" version = "0.6.1" @@ -1605,6 +1653,15 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1655,6 +1712,19 @@ dependencies = [ "http", ] +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "spin 0.9.8", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.4.1" @@ -2269,6 +2339,12 @@ dependencies = [ "libc", ] +[[package]] +name = "managed" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" + [[package]] name = "match_cfg" version = "0.1.0" @@ -2878,6 +2954,30 @@ dependencies = [ "syn 2.0.37", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.67" @@ -3507,6 +3607,21 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +[[package]] +name = "smoltcp" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d2e3a36ac8fea7b94e666dfa3871063d6e0a5c9d5d4fec9a1a6b7b6760f0229" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "cfg-if", + "defmt", + "heapless", + "log", + "managed", +] + [[package]] name = "socket2" version = "0.4.9" @@ -3552,6 +3667,12 @@ dependencies = [ "der", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "state" version = "0.6.0" diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 767ae16f9..88b8d0ed8 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -65,6 +65,8 @@ tun = { git = "https://github.com/Watfaq/rust-tun.git", rev = "8f7568190f1200d3e netstack-lwip = { git = "https://github.com/Watfaq/netstack-lwip.git", rev = "2817bf82740e04bbee6b7bf1165f55657a6ed163" } boringtun = { version = "0.6.0" } +smoltcp = { version = "0.10", default-features = false, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-udp", "socket-tcp"] } + serde = { version = "1.0", features=["derive"] } serde_yaml = "0.9" diff --git a/clash_lib/src/proxy/mod.rs b/clash_lib/src/proxy/mod.rs index 9c9de4721..b632ca95f 100644 --- a/clash_lib/src/proxy/mod.rs +++ b/clash_lib/src/proxy/mod.rs @@ -32,7 +32,7 @@ pub mod trojan; pub mod tun; pub mod utils; pub mod vmess; -//pub mod wg; +pub mod wg; pub mod converters; diff --git a/clash_lib/src/proxy/wg/device.rs b/clash_lib/src/proxy/wg/device.rs new file mode 100644 index 000000000..a3b0b739f --- /dev/null +++ b/clash_lib/src/proxy/wg/device.rs @@ -0,0 +1,14 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, +}; + +use bytes::Bytes; + +use super::events::BusSender; + +pub struct VirtualIpDevice { + mtu: usize, + bus_sender: BusSender, + queue: Arc>>, +} diff --git a/clash_lib/src/proxy/wg/events.rs b/clash_lib/src/proxy/wg/events.rs new file mode 100644 index 000000000..5dc951d82 --- /dev/null +++ b/clash_lib/src/proxy/wg/events.rs @@ -0,0 +1,108 @@ +use std::sync::{atomic::AtomicU32, Arc}; + +use bytes::Bytes; +use tracing::error; + +/// Layer 7 protocols for ports. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub enum PortProtocol { + /// TCP + Tcp, + /// UDP + Udp, +} + +#[derive(Clone)] +pub struct Bus { + counter: Arc, + bus: tokio::sync::broadcast::Sender<(u32, Event)>, +} + +impl Bus { + pub fn new() -> Self { + let (tx, _) = tokio::sync::broadcast::channel(1024); + Self { + counter: Arc::new(AtomicU32::new(0)), + bus: tx, + } + } + + pub fn new_endpoint(&self) -> BusEndpoint { + let id = self + .counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let tx = self.bus.clone(); + let rx = self.bus.subscribe(); + + let tx = BusSender { id, tx }; + BusEndpoint { id, tx, rx } + } +} + +pub struct BusEndpoint { + id: u32, + tx: BusSender, + rx: tokio::sync::broadcast::Receiver<(u32, Event)>, +} + +impl BusEndpoint { + /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. + pub fn send(&self, event: Event) { + self.tx.send(event) + } + + /// Returns the unique sequential ID of this endpoint. + pub fn id(&self) -> u32 { + self.id + } + + /// Awaits the next `Event` on the bus to be read. + pub async fn recv(&mut self) -> Event { + loop { + match self.rx.recv().await { + Ok((id, event)) => { + if id == self.id { + // If the event was sent by this endpoint, it is skipped + continue; + } else { + return event; + } + } + Err(_) => { + error!("Failed to read event bus from endpoint #{}", self.id); + return futures::future::pending().await; + } + } + } + } + + /// Creates a new sender for this endpoint that can be cloned. + pub fn sender(&self) -> BusSender { + self.tx.clone() + } +} + +/// Events that go on the bus between the local server, smoltcp, and WireGuard. +#[derive(Debug, Clone)] +pub enum Event { + InboundInternetPacket(PortProtocol, Bytes), + /// IP packet to be sent through the WireGuard tunnel as crafted by the virtual device. + OutboundInternetPacket(Bytes), +} + +#[derive(Clone)] +pub struct BusSender { + id: u32, + tx: tokio::sync::broadcast::Sender<(u32, Event)>, +} + +impl BusSender { + /// Sends the event on the bus. Note that the messages sent by this endpoint won't reach itself. + pub fn send(&self, event: Event) { + match self.tx.send((self.id, event)) { + Ok(_) => {} + Err(_) => error!("Failed to send event to bus from endpoint #{}", self.id), + } + } +} diff --git a/clash_lib/src/proxy/wg/mod.rs b/clash_lib/src/proxy/wg/mod.rs index fc88d1bfc..877274682 100644 --- a/clash_lib/src/proxy/wg/mod.rs +++ b/clash_lib/src/proxy/wg/mod.rs @@ -5,17 +5,23 @@ use std::{ }; use crate::{ - app::{dispatcher::BoxedChainedStream, dns::ThreadSafeDNSResolver}, + app::{ + dispatcher::{BoxedChainedDatagram, BoxedChainedStream}, + dns::ThreadSafeDNSResolver, + }, session::{Session, SocksAddr}, }; -use super::{ - AnyOutboundDatagram, AnyOutboundHandler, AnyStream, CommonOption, OutboundHandler, OutboundType, -}; +use super::{AnyOutboundHandler, AnyStream, CommonOption, OutboundHandler, OutboundType}; use async_trait::async_trait; pub use netstack_lwip as netstack; +mod device; +mod events; +mod stack; +mod wireguard; + pub struct Opts { pub name: String, pub common_opts: CommonOption, @@ -34,15 +40,11 @@ pub struct Opts { pub struct Handler { opts: Opts, - - device: boringtun::device::Device, } impl Handler { pub fn new(opts: Opts) -> AnyOutboundHandler { - let device_cfg = boringtun::device::DeviceConfig::default(); - let device = boringtun::device::Device::new("utun", device_cfg).unwrap(); - Arc::new(Self { opts, device }) + Arc::new(Self { opts }) } } @@ -88,7 +90,7 @@ impl OutboundHandler for Handler { &self, sess: &Session, resolver: ThreadSafeDNSResolver, - ) -> io::Result { + ) -> io::Result { todo!() } } diff --git a/clash_lib/src/proxy/wg/stack/mod.rs b/clash_lib/src/proxy/wg/stack/mod.rs new file mode 100644 index 000000000..0a89b63f6 --- /dev/null +++ b/clash_lib/src/proxy/wg/stack/mod.rs @@ -0,0 +1,65 @@ +pub mod tcp; +pub mod udp; + +use async_trait::async_trait; +use std::fmt::{Display, Formatter}; + +use super::{device::VirtualIpDevice, events::PortProtocol}; + +#[async_trait] +pub trait VirtualInterfacePoll { + /// Initializes the virtual interface and processes incoming data to be dispatched + /// to the WireGuard tunnel and to the real client. + async fn poll_loop(mut self, device: VirtualIpDevice) -> std::io::Result<()>; +} + +/// Virtual port. +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub struct VirtualPort(u16, PortProtocol); + +impl VirtualPort { + /// Create a new `VirtualPort` instance, with the given port number and associated protocol. + pub fn new(port: u16, proto: PortProtocol) -> Self { + VirtualPort(port, proto) + } + + /// The port number + pub fn num(&self) -> u16 { + self.0 + } + + /// The protocol of this port. + pub fn proto(&self) -> PortProtocol { + self.1 + } +} + +impl From for u16 { + fn from(port: VirtualPort) -> Self { + port.num() + } +} + +impl From<&VirtualPort> for u16 { + fn from(port: &VirtualPort) -> Self { + port.num() + } +} + +impl From for PortProtocol { + fn from(port: VirtualPort) -> Self { + port.proto() + } +} + +impl From<&VirtualPort> for PortProtocol { + fn from(port: &VirtualPort) -> Self { + port.proto() + } +} + +impl Display for VirtualPort { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}:{}]", self.num(), self.proto()) + } +} diff --git a/clash_lib/src/proxy/wg/stack/tcp.rs b/clash_lib/src/proxy/wg/stack/tcp.rs new file mode 100644 index 000000000..e00d56e91 --- /dev/null +++ b/clash_lib/src/proxy/wg/stack/tcp.rs @@ -0,0 +1,42 @@ +use std::net::IpAddr; + +use async_trait::async_trait; +use smoltcp::{ + iface::{Config, Interface}, + socket::tcp::Socket, + time::Instant, +}; + +use crate::proxy::wg::{device::VirtualIpDevice, events::Bus}; + +use super::VirtualInterfacePoll; + +pub struct VirtualTcpDevice { + source_peer_ip: IpAddr, + bus: Bus, +} + +impl VirtualTcpDevice { + pub fn new(source_peer_ip: IpAddr, bus: Bus) -> Self { + Self { + source_peer_ip, + bus, + } + } + + pub fn new_client_socket() -> Socket<'static> { + Socket::new( + smoltcp::socket::tcp::SocketBuffer::new(vec![0; 65535]), + smoltcp::socket::tcp::SocketBuffer::new(vec![0; 65535]), + ) + } +} + +#[async_trait] +impl VirtualInterfacePoll for VirtualTcpDevice { + async fn poll_loop(self, device: VirtualIpDevice) -> std::io::Result<()> { + let mut config = Config::new(smoltcp::wire::HardwareAddress::Ip); + config.random_seed = rand::random(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + } +} diff --git a/clash_lib/src/proxy/wg/stack/udp.rs b/clash_lib/src/proxy/wg/stack/udp.rs new file mode 100644 index 000000000..e69de29bb diff --git a/clash_lib/src/proxy/wg/wireguard.rs b/clash_lib/src/proxy/wg/wireguard.rs new file mode 100644 index 000000000..cd6c64ce5 --- /dev/null +++ b/clash_lib/src/proxy/wg/wireguard.rs @@ -0,0 +1,236 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + time::Duration, +}; + +use async_recursion::async_recursion; +use boringtun::noise::{errors::WireGuardError, Tunn, TunnResult}; + +use smoltcp::wire::{IpProtocol, IpVersion, Ipv4Packet, Ipv6Packet}; +use tokio::net::UdpSocket; +use tracing::{enabled, error, trace, warn}; + +use crate::Error; + +use super::events::{Bus, Event, PortProtocol}; + +pub struct WireguardTunnel { + pub(crate) source_peer_ip: IpAddr, + peer: Box, + udp: UdpSocket, + pub(crate) endpoint: SocketAddr, + bus: Bus, +} + +pub struct Config { + pub private_key: [u8; 32], + pub endpoint_public_key: [u8; 32], + pub preshared_key: Option<[u8; 32]>, + pub remote_endpoint: SocketAddr, + pub source_peer_ip: IpAddr, + pub keepalive_seconds: Option, +} + +impl WireguardTunnel { + pub async fn new(config: Config, bus: Bus) -> Result { + let source_peer_ip = config.source_peer_ip; + let peer = Box::new( + Tunn::new( + config.private_key.into(), + config.endpoint_public_key.into(), + config.preshared_key, + config.keepalive_seconds, + 0, + None, + ) + .map_err(|x| { + Error::InvalidConfig(format!("failed to create wireguard tunnel: {}", x)) + })?, + ); + + let remote_endpoint = config.remote_endpoint; + let udp = UdpSocket::bind("127.0.0.1:0").await?; + + Ok(Self { + source_peer_ip, + peer, + udp, + endpoint: remote_endpoint, + bus, + }) + } + + pub async fn send_ip_packet(&mut self, packet: &[u8]) -> Result<(), Error> { + trace_ip_packet("Sending IP packet", packet); + + let mut send_buf = [0u8; 65535]; + match self.peer.encapsulate(packet, &mut send_buf) { + boringtun::noise::TunnResult::Done => {} + boringtun::noise::TunnResult::Err(e) => { + error!("failed to encapsulate packet: {e:?}"); + } + boringtun::noise::TunnResult::WriteToNetwork(packet) => { + self.udp.send_to(&packet, self.endpoint).await?; + } + _ => { + error!("unexpected result from encapsulate"); + } + } + Ok(()) + } + + pub async fn start_forwarding(&mut self) { + let mut ep = self.bus.new_endpoint(); + + loop { + if let Event::OutboundInternetPacket(data) = ep.recv().await { + if let Err(e) = self.send_ip_packet(&data).await { + error!("failed to send packet: {}", e); + } + } + } + } + + pub async fn start_heartbeat(&mut self) { + loop { + let mut send_buf = [0u8; 65535]; + let tun_result = self.peer.update_timers(&mut send_buf); + self.handle_routine_result(tun_result).await; + } + } + + pub async fn start_receiving(&mut self) { + let ep = self.bus.new_endpoint(); + + loop { + let mut recv_buf = [0u8; 65535]; + let mut send_buf = [0u8; 65535]; + + let size = match self.udp.recv(&mut recv_buf).await { + Ok(size) => size, + Err(e) => { + error!("failed to receive packet: {e:?}"); + tokio::time::sleep(Duration::from_millis(1)).await; + continue; + } + }; + + let data = &recv_buf[..size]; + match self.peer.decapsulate(None, data, &mut send_buf) { + TunnResult::Done => todo!(), + TunnResult::Err(_) => todo!(), + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(&packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + error!("failed to send packet: {}", e); + continue; + } + } + + loop { + let mut send_buf = [0u8; 65535]; + match self.peer.decapsulate(None, &[], &mut send_buf) { + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + error!("Failed to send decapsulation-instructed packet to WireGuard endpoint: {:?}", e); + break; + } + }; + } + _ => { + break; + } + } + } + } + TunnResult::WriteToTunnelV4(packet, _) | TunnResult::WriteToTunnelV6(packet, _) => { + trace_ip_packet("Received IP packet", packet); + + if let Some(proto) = self.route_protocol(packet) { + ep.send(Event::InboundInternetPacket(proto, packet.to_vec().into())); + // TODO: avoid copy + } + } + } + } + } + + #[async_recursion] + async fn handle_routine_result<'a: 'async_recursion>(&mut self, result: TunnResult<'a>) { + match result { + TunnResult::Done => { + tokio::time::sleep(Duration::from_millis(1)).await; + } + TunnResult::Err(WireGuardError::ConnectionExpired) => { + warn!("wireguard connection expired"); + let mut buf = [0u8; 65535]; + let tun_result = self.peer.format_handshake_initiation(&mut buf[..], false); + self.handle_routine_result(tun_result).await; + } + TunnResult::Err(e) => { + error!("wireguard error: {e:?}"); + } + TunnResult::WriteToNetwork(packet) => { + match self.udp.send_to(&packet, self.endpoint).await { + Ok(_) => {} + Err(e) => { + error!("failed to send packet: {}", e); + } + } + } + _ => { + error!("unexpected result from wireguard"); + } + } + } + + /// Determine the inner protocol of the incoming IP packet (TCP/UDP). + fn route_protocol(&self, packet: &[u8]) -> Option { + match IpVersion::of_packet(packet) { + Ok(IpVersion::Ipv4) => Ipv4Packet::new_checked(&packet) + .ok() + // Only care if the packet is destined for this tunnel + .filter(|packet| Ipv4Addr::from(packet.dst_addr()) == self.source_peer_ip) + .and_then(|packet| match packet.next_header() { + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), + // Unrecognized protocol, so we cannot determine where to route + _ => None, + }), + Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) + .ok() + // Only care if the packet is destined for this tunnel + .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) + .and_then(|packet| match packet.next_header() { + IpProtocol::Tcp => Some(PortProtocol::Tcp), + IpProtocol::Udp => Some(PortProtocol::Udp), + // Unrecognized protocol, so we cannot determine where to route + _ => None, + }), + _ => None, + } + } +} + +fn trace_ip_packet(message: &str, packet: &[u8]) { + if enabled!(tracing::Level::TRACE) { + use smoltcp::wire::*; + + match IpVersion::of_packet(packet) { + Ok(IpVersion::Ipv4) => trace!( + "{}: {}", + message, + PrettyPrinter::>::new("", &packet) + ), + Ok(IpVersion::Ipv6) => trace!( + "{}: {}", + message, + PrettyPrinter::>::new("", &packet) + ), + _ => {} + } + } +}