From de8e93a882f3894d534fd666e8533b6b44e0bbbd Mon Sep 17 00:00:00 2001 From: tickbh Date: Thu, 25 Jan 2024 15:45:30 +0800 Subject: [PATCH] fix server --- examples/proxy.rs | 7 +-- src/arg.rs | 29 +-------- src/config/mod.rs | 2 + src/config/wrap.rs | 37 +++++++++++ src/option.rs | 112 +++++++++++++++++++-------------- src/prot/close.rs | 12 ++-- src/prot/create.rs | 7 ++- src/prot/data.rs | 8 ++- src/prot/frame.rs | 19 ++++-- src/streams/center_trans.rs | 77 +++++++++++++++++++++++ src/streams/mod.rs | 4 +- src/wmcore.rs | 121 ++++++++++++++++-------------------- tests/mapping.rs | 8 +-- tests/proxy.rs | 29 ++++++--- 14 files changed, 298 insertions(+), 174 deletions(-) create mode 100644 src/config/wrap.rs create mode 100644 src/streams/center_trans.rs diff --git a/examples/proxy.rs b/examples/proxy.rs index af49753..20fb416 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -111,11 +111,9 @@ async fn main() { let password = "wmproxy".to_string(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .username(Some(username.clone())) .password(Some(password.clone())) - .center(true) - .mode("server".to_string()) .into_value() .unwrap(); @@ -124,10 +122,9 @@ async fn main() { .unwrap(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .username(Some(username.clone())) .password(Some(password.clone())) - .center(true) .server(Some(format!("{}", server_addr))) .into_value() .unwrap(); diff --git a/src/arg.rs b/src/arg.rs index 07ce52c..44e3340 100644 --- a/src/arg.rs +++ b/src/arg.rs @@ -14,13 +14,11 @@ use std::process::id; use std::{ - fmt::Display, fs::File, io::{self, Read, Write}, - net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, path::PathBuf, process::exit, - str::FromStr, }; use bpaf::*; @@ -29,38 +27,15 @@ use webparse::{Request, Url}; use wenmeng::Client; use crate::reverse::StreamConfig; -use crate::ConfigDuration; use crate::{ option::proxy_config, reverse::{HttpConfig, LocationConfig, ServerConfig, UpstreamConfig}, ConfigHeader, ConfigLog, ConfigOption, FileServer, ProxyConfig, ProxyResult, }; +use crate::{ConfigDuration, WrapAddr}; const VERSION: &str = env!("CARGO_PKG_VERSION"); -#[derive(Debug, Clone, Copy)] -pub struct WrapAddr(pub SocketAddr); - -impl FromStr for WrapAddr { - type Err = AddrParseError; - - fn from_str(s: &str) -> Result { - if s.starts_with(":") { - let addr = format!("127.0.0.1{s}").parse::()?; - Ok(WrapAddr(addr)) - } else { - let addr = s.parse::()?; - Ok(WrapAddr(addr)) - } - } -} - -impl Display for WrapAddr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("{}", self.0)) - } -} - #[derive(Debug, Clone, Bpaf)] #[allow(dead_code)] struct Shared { diff --git a/src/config/mod.rs b/src/config/mod.rs index 057d047..9470661 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,6 +17,7 @@ mod log; mod header; mod rate; mod ip_sets; +mod wrap; use std::{str::FromStr, fmt::{Display, self}, marker::PhantomData}; @@ -26,6 +27,7 @@ pub use self::log::ConfigLog; pub use self::header::{ConfigHeader, HeaderOper}; pub use self::rate::ConfigRate; pub use self::ip_sets::*; +pub use self::wrap::*; use serde::{Serializer, Deserializer, de::{Visitor, Error, self}}; use serde_with::{SerializeAs, DeserializeAs}; diff --git a/src/config/wrap.rs b/src/config/wrap.rs new file mode 100644 index 0000000..cf3dbc4 --- /dev/null +++ b/src/config/wrap.rs @@ -0,0 +1,37 @@ +// Copyright 2022 - 2024 Wenmeng See the COPYRIGHT +// file at the top-level directory of this distribution. +// +// Licensed under the Apache License, Version 2.0 , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +// +// Author: tickbh +// ----- +// Created Date: 2024/01/25 02:13:35 + +use std::{fmt::Display, net::{AddrParseError, SocketAddr}, str::FromStr}; + + +#[derive(Debug, Clone, Copy)] +pub struct WrapAddr(pub SocketAddr); + +impl FromStr for WrapAddr { + type Err = AddrParseError; + + fn from_str(s: &str) -> Result { + if s.starts_with(":") { + let addr = format!("127.0.0.1{s}").parse::()?; + Ok(WrapAddr(addr)) + } else { + let addr = s.parse::()?; + Ok(WrapAddr(addr)) + } + } +} + +impl Display for WrapAddr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.0)) + } +} diff --git a/src/option.rs b/src/option.rs index a5dafb9..d6341df 100644 --- a/src/option.rs +++ b/src/option.rs @@ -14,7 +14,7 @@ use std::{ collections::{HashMap, HashSet}, fs::File, io::{self, BufReader}, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, SocketAddr}, process, sync::Arc, time::Duration, @@ -30,8 +30,7 @@ use tokio::net::TcpListener; use tokio_rustls::{rustls, TlsAcceptor}; use crate::{ - reverse::{HttpConfig, StreamConfig, UpstreamConfig}, - CenterClient, Flag, Helper, MappingConfig, OneHealth, ProxyError, ProxyResult, + reverse::{HttpConfig, StreamConfig, UpstreamConfig}, CenterClient, Flag, Helper, MappingConfig, OneHealth, ProxyError, ProxyResult, WrapAddr }; pub struct Builder { @@ -53,26 +52,34 @@ impl Builder { }) } - pub fn mode(self, mode: String) -> Builder { + // pub fn mode(self, mode: String) -> Builder { + // self.and_then(|mut proxy| { + // proxy.mode = mode; + // Ok(proxy) + // }) + // } + + pub fn add_flag(self, flag: Flag) -> Builder { self.and_then(|mut proxy| { - proxy.mode = mode; + proxy.flag.set(flag, true); Ok(proxy) }) } - pub fn add_flag(self, flag: Flag) -> Builder { + pub fn bind(self, addr: SocketAddr) -> Builder { self.and_then(|mut proxy| { - proxy.flag.set(flag, true); + proxy.bind = Some(WrapAddr(addr)); Ok(proxy) }) } - pub fn bind_addr(self, addr: SocketAddr) -> Builder { + pub fn center_addr(self, addr: SocketAddr) -> Builder { self.and_then(|mut proxy| { - proxy.bind_addr = addr; + proxy.center_addr = Some(WrapAddr(addr)); Ok(proxy) }) } + pub fn server(self, addr: Option) -> Builder { self.and_then(|mut proxy| { @@ -88,13 +95,6 @@ impl Builder { }) } - pub fn center(self, center: bool) -> Builder { - self.and_then(|mut proxy| { - proxy.center = center; - Ok(proxy) - }) - } - pub fn tc(self, is_tls: bool) -> Builder { self.and_then(|mut proxy| { proxy.tc = is_tls; @@ -197,33 +197,47 @@ fn default_bind_addr() -> SocketAddr { "127.0.0.1:8090".parse().unwrap() } -fn default_bool_true() -> bool { - true -} - /// 代理类, 一个代理类启动一种类型的代理 #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, Bpaf)] pub struct ProxyConfig { + /// 代理id + #[bpaf( + short('s'), + long + )] + #[serde(default)] + pub(crate) server_id: u32, + /// 代理绑定端口地址 #[bpaf( - fallback(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8090)), - display_fallback, + // fallback(Some(WrapAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8090)))), + // display_fallback, short('b'), long )] - #[serde(default = "default_bind_addr")] - pub(crate) bind_addr: SocketAddr, + #[serde_as(as = "Option")] + pub(crate) bind: Option, + + /// 代理绑定端口地址 + #[bpaf( + // fallback(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8090)), + // display_fallback, + short('c'), + long + )] + #[serde_as(as = "Option")] + pub(crate) center_addr: Option, /// 代理种类, 如http https socks5 #[bpaf(fallback(Flag::default()))] #[serde_as(as = "DisplayFromStr")] #[serde(default)] pub(crate) flag: Flag, - /// 启动程序类型 - #[bpaf(fallback("client".to_string()))] - #[serde(default)] - pub(crate) mode: String, + // /// 启动程序类型 + // #[bpaf(fallback("client".to_string()))] + // #[serde(default)] + // pub(crate) mode: String, /// 连接代理服务端地址 #[bpaf(short('S'), long("server"))] @@ -249,9 +263,6 @@ pub struct ProxyConfig { /// 内网映射的证书key pub(crate) map_key: Option, - /// 是否启用协议转发 - #[serde(default = "default_bool_true")] - pub(crate) center: bool, /// 连接服务端是否启用tls #[serde(default)] pub(crate) ts: bool, @@ -314,9 +325,11 @@ impl Default for ConfigOption { impl Default for ProxyConfig { fn default() -> Self { Self { + server_id: 0, flag: Flag::HTTP | Flag::HTTPS | Flag::SOCKS5, - mode: "client".to_string(), - bind_addr: default_bind_addr(), + // mode: "client".to_string(), + bind: Some(WrapAddr(default_bind_addr())), + center_addr: None, server: None, username: None, password: None, @@ -328,7 +341,6 @@ impl Default for ProxyConfig { map_cert: None, map_key: None, - center: false, ts: false, tc: false, two_way_tls: false, @@ -561,13 +573,13 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC } } - pub fn is_client(&self) -> bool { - self.mode.eq_ignore_ascii_case("client") - } + // pub fn is_client(&self) -> bool { + // self.mode.eq_ignore_ascii_case("client") + // } - pub fn is_server(&self) -> bool { - self.mode.eq_ignore_ascii_case("server") - } + // pub fn is_server(&self) -> bool { + // self.mode.eq_ignore_ascii_case("server") + // } pub async fn bind( &self, @@ -575,13 +587,13 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC Option, Option>, Option, + Option, Option, )> { - let addr = self.bind_addr.clone(); let proxy_accept = self.get_tls_accept().await.ok(); let client = self.get_tls_request().await.ok(); let mut center_client = None; - if self.center { + if self.bind.is_some() { if let Some(server) = self.server.clone() { let mut center = CenterClient::new( self.clone(), @@ -609,9 +621,19 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC center_client = Some(center); } } - log::info!("绑定代理:{:?},提供代理功能。", addr); - let center_listener = Some(Helper::bind(addr).await?); - Ok((proxy_accept, client, center_listener, center_client)) + let client_listener = if let Some(bind) = self.bind { + log::info!("绑定代理:{:?},提供代理功能。", bind.0); + Some(Helper::bind(bind.0).await?) + } else { + None + }; + let center_listener = if let Some(center) = self.center_addr { + log::info!("绑定代理:{:?},提供中心代理功能。", center.0); + Some(Helper::bind(center.0).await?) + } else { + None + }; + Ok((proxy_accept, client, client_listener, center_listener, center_client)) } pub async fn bind_map( diff --git a/src/prot/close.rs b/src/prot/close.rs index 5ab8954..29e75d8 100644 --- a/src/prot/close.rs +++ b/src/prot/close.rs @@ -22,29 +22,31 @@ use super::{ProtFrameHeader, read_short_string, write_short_string}; /// 旧的Socket连接关闭, 接收到则关闭掉当前的连接 #[derive(Debug)] pub struct ProtClose { + server_id: u32, sock_map: u32, reason: String, } impl ProtClose { - pub fn new(sock_map: u32) -> ProtClose { - ProtClose { sock_map, reason: String::new() } + pub fn new(server_id: u32, sock_map: u32) -> ProtClose { + ProtClose { server_id, sock_map, reason: String::new() } } - pub fn new_by_reason(sock_map: u32, reason: String) -> ProtClose { - ProtClose { sock_map, reason } + pub fn new_by_reason(server_id: u32, sock_map: u32, reason: String) -> ProtClose { + ProtClose { server_id, sock_map, reason } } pub fn parse(header: ProtFrameHeader, mut buf: T) -> ProxyResult { let reason = read_short_string(&mut buf)?; Ok(ProtClose { + server_id: header.server_id(), sock_map: header.sock_map(), reason, }) } pub fn encode(self, buf: &mut B) -> ProxyResult { - let mut head = ProtFrameHeader::new(ProtKind::Close, ProtFlag::zero(), self.sock_map); + let mut head = ProtFrameHeader::new(ProtKind::Close, ProtFlag::zero(), self.sock_map, self.server_id); head.length = self.reason.as_bytes().len() as u32 + 1; let mut size = 0; size += head.encode(buf)?; diff --git a/src/prot/create.rs b/src/prot/create.rs index 72259a7..00146f2 100644 --- a/src/prot/create.rs +++ b/src/prot/create.rs @@ -24,14 +24,16 @@ use super::ProtFrameHeader; #[derive(Debug)] #[allow(dead_code)] pub struct ProtCreate { + server_id: u32, sock_map: u32, mode: u8, domain: Option, } impl ProtCreate { - pub fn new(sock_map: u32, domain: Option) -> Self { + pub fn new(server_id: u32,sock_map: u32, domain: Option) -> Self { Self { + server_id, sock_map, mode: 0, domain, @@ -49,6 +51,7 @@ impl ProtCreate { domain = Some(String::from_utf8_lossy(data).to_string()); } Ok(ProtCreate { + server_id: header.server_id(), sock_map: header.sock_map(), mode: 0, domain, @@ -56,7 +59,7 @@ impl ProtCreate { } pub fn encode(self, buf: &mut B) -> ProxyResult { - let mut head = ProtFrameHeader::new(ProtKind::Create, ProtFlag::zero(), self.sock_map); + let mut head = ProtFrameHeader::new(ProtKind::Create, ProtFlag::zero(), self.sock_map, self.server_id); let domain_len = self.domain.as_ref().map(|s| s.as_bytes().len() as u32).unwrap_or(0); head.length = 1 + domain_len; let mut size = 0; diff --git a/src/prot/data.rs b/src/prot/data.rs index e7629e4..ecc4724 100644 --- a/src/prot/data.rs +++ b/src/prot/data.rs @@ -22,18 +22,20 @@ use super::ProtFrameHeader; /// Socket的数据消息包 #[derive(Debug)] pub struct ProtData { + server_id: u32, sock_map: u32, data: Vec, } impl ProtData { - pub fn new(sock_map: u32, data: Vec) -> ProtData { - Self { sock_map, data } + pub fn new(server_id: u32,sock_map: u32, data: Vec) -> ProtData { + Self { server_id, sock_map, data } } pub fn parse(header: ProtFrameHeader, mut buf: T) -> ProxyResult { log::trace!("代理中心: 解码Data数据长度={}", header.length); Ok(Self { + server_id: header.server_id(), sock_map: header.sock_map(), data: buf.advance_chunk(header.length as usize).to_vec(), }) @@ -41,7 +43,7 @@ impl ProtData { pub fn encode(mut self, buf: &mut B) -> ProxyResult { log::trace!("代理中心: 编码Data数据长度={}", self.data.len()); - let mut head = ProtFrameHeader::new(ProtKind::Data, ProtFlag::zero(), self.sock_map); + let mut head = ProtFrameHeader::new(ProtKind::Data, ProtFlag::zero(), self.sock_map, self.server_id); head.length = self.data.len() as u32; let mut size = 0; size += head.encode(buf)?; diff --git a/src/prot/frame.rs b/src/prot/frame.rs index c6577a0..4eb8193 100644 --- a/src/prot/frame.rs +++ b/src/prot/frame.rs @@ -11,6 +11,7 @@ // Created Date: 2023/09/22 10:30:10 +use tokio_util::bytes::buf; use webparse::{Buf, http2::frame::{read_u24, encode_u24}, BufMut}; use crate::{ProxyResult, MappingConfig}; @@ -28,6 +29,8 @@ pub struct ProtFrameHeader { flag: ProtFlag, /// 3个字节, socket在内存中相应的句柄, 客户端发起为单数, 服务端发起为双数 sock_map: u32, + /// 服务器的id + server_id: u32, } #[derive(Debug)] @@ -45,17 +48,22 @@ pub enum ProtFrame { } impl ProtFrameHeader { - pub const FRAME_HEADER_BYTES: usize = 8; + pub const FRAME_HEADER_BYTES: usize = 12; - pub fn new(kind: ProtKind, flag: ProtFlag, sock_map: u32) -> ProtFrameHeader { + pub fn new(kind: ProtKind, flag: ProtFlag, sock_map: u32, server_id: u32) -> ProtFrameHeader { ProtFrameHeader { length: 0, kind, flag, sock_map, + server_id, } } + pub fn server_id(&self) -> u32 { + self.server_id + } + pub fn sock_map(&self) -> u32 { self.sock_map } @@ -81,11 +89,13 @@ impl ProtFrameHeader { let kind = buffer.get_u8(); let flag = buffer.get_u8(); let sock_map = read_u24(buffer); + let server_id = buffer.get_u32(); Ok(ProtFrameHeader { length, kind: ProtKind::new(kind), flag: ProtFlag::new(flag), sock_map, + server_id, }) } @@ -96,6 +106,7 @@ impl ProtFrameHeader { size += buffer.put_u8(self.kind.encode()); size += buffer.put_u8(self.flag.bits()); size += encode_u24(buffer, self.sock_map); + size += buffer.put_u32(self.server_id); Ok(size) } @@ -133,8 +144,8 @@ impl ProtFrame { Ok(size) } - pub fn new_create(sock_map: u32, domain: Option) -> Self { - Self::Create(ProtCreate::new(sock_map, domain)) + pub fn new_create(server_id: u32, sock_map: u32, domain: Option) -> Self { + Self::Create(ProtCreate::new(server_id, sock_map, domain)) } pub fn new_close(sock_map: u32) -> Self { diff --git a/src/streams/center_trans.rs b/src/streams/center_trans.rs new file mode 100644 index 0000000..66609e7 --- /dev/null +++ b/src/streams/center_trans.rs @@ -0,0 +1,77 @@ +// Copyright 2022 - 2023 Wenmeng See the COPYRIGHT +// file at the top-level directory of this distribution. +// +// Licensed under the Apache License, Version 2.0 , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +// +// Author: tickbh +// ----- +// Created Date: 2023/09/25 10:08:56 + +use std::{io, sync::Arc}; + +use tokio::{ + io::{AsyncRead, AsyncWrite}, +}; + +use tokio_rustls::{TlsConnector}; + + +use wenmeng::MaybeHttpsStream; + +use crate::{ + HealthCheck, ProxyResult, +}; + +/// 中心服务端 +/// 接受中心客户端的连接,并且将信息处理或者转发 +pub struct CenterTrans { + server: String, + domain: Option, + tls_client: Option>, +} + +impl CenterTrans { + pub fn new( + server: String, + domain: Option, + tls_client: Option>, + ) -> Self { + Self { + server, + domain, + tls_client, + } + } + + pub async fn serve(&mut self, mut stream: T) -> ProxyResult<()> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let mut server = if self.tls_client.is_some() { + let connector = TlsConnector::from(self.tls_client.clone().unwrap()); + let stream = HealthCheck::connect(&self.server).await?; + // 这里的域名只为认证设置 + let domain = rustls::ServerName::try_from( + &*self + .domain + .clone() + .unwrap_or("soft.wm-proxy.com".to_string()), + ) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; + + let outbound = connector.connect(domain, stream).await?; + MaybeHttpsStream::Https(outbound) + } else { + let outbound = HealthCheck::connect(&self.server).await?; + MaybeHttpsStream::Http(outbound) + }; + + tokio::spawn(async move { + let _ = tokio::io::copy_bidirectional(&mut stream, &mut server).await; + }); + Ok(()) + } +} diff --git a/src/streams/mod.rs b/src/streams/mod.rs index 00a884d..18e1d6b 100644 --- a/src/streams/mod.rs +++ b/src/streams/mod.rs @@ -12,10 +12,12 @@ mod center_client; mod center_server; +mod center_trans; mod trans_stream; mod virtual_stream; pub use center_client::CenterClient; pub use center_server::CenterServer; +pub use center_trans::CenterTrans; pub use trans_stream::TransStream; -pub use virtual_stream::VirtualStream; \ No newline at end of file +pub use virtual_stream::VirtualStream; diff --git a/src/wmcore.rs b/src/wmcore.rs index fc20505..14acbe5 100644 --- a/src/wmcore.rs +++ b/src/wmcore.rs @@ -27,15 +27,11 @@ use tokio::{ Mutex, }, }; -use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; +use tokio_rustls::{rustls, TlsAcceptor}; use crate::{ - option::ConfigOption, - proxy::ProxyServer, - reverse::{HttpConfig, ServerConfig, StreamConfig, StreamUdp}, - ActiveHealth, CenterClient, CenterServer, HealthCheck, OneHealth, - ProxyResult, Helper, + option::ConfigOption, proxy::ProxyServer, reverse::{HttpConfig, ServerConfig, StreamConfig, StreamUdp}, ActiveHealth, CenterClient, CenterServer, CenterTrans, Helper, OneHealth, ProxyResult }; pub struct WMCore { @@ -45,6 +41,7 @@ pub struct WMCore { health_sender: Option>>, pub proxy_accept: Option, pub proxy_client: Option>, + pub client_listener: Option, pub center_listener: Option, pub map_http_listener: Option, @@ -72,6 +69,7 @@ impl WMCore { health_sender: None, proxy_accept: None, proxy_client: None, + client_listener: None, center_listener: None, map_http_listener: None, @@ -91,7 +89,9 @@ impl WMCore { } } - async fn deal_stream( + /// 来自中心端的连接, 如果存在上级则无条件转发到上级 + /// 如果不传在上级, 则构建中心服处理该请求 + async fn deal_center_stream( &mut self, inbound: T, _addr: SocketAddr, @@ -100,38 +100,45 @@ impl WMCore { where T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, { - // 转发到服务端 - if let Some(client) = &mut self.center_client { - return client.deal_new_stream(inbound).await; - } if let Some(option) = &mut self.option.proxy { - // 服务端开代理, 接收到客户端一律用协议处理 - if option.center && option.is_server() { + if let Some(server) = option.server.clone() { + let mut server = CenterTrans::new(server, option.domain.clone(), tls_client); + return server.serve(inbound).await; + } else { let server = CenterServer::new(option.clone()); self.center_servers.push(server); return self.center_servers.last_mut().unwrap().serve(inbound).await; } + } + Ok(()) + } - let flag = option.flag; - let domain = option.domain.clone(); - if let Some(server) = option.server.clone() { - tokio::spawn(async move { - // 转到上层服务器进行处理 - let _e = Self::transfer_server(domain, tls_client, inbound, server).await; - }); - } else { - let proxy_server = ProxyServer::new( - flag, - option.username.clone(), - option.password.clone(), - option.udp_bind.clone(), - None, - ); - tokio::spawn(async move { - // tcp的连接被移动到该协程中,我们只要专注的处理该stream即可 - let _ = proxy_server.deal_proxy(inbound).await; - }); - } + /// 处理客户端的请求, 仅可能有上级转发给上级 + /// 没有上级直接处理当前代理数据 + async fn deal_client_stream( + &mut self, + inbound: T, + _addr: SocketAddr, + ) -> ProxyResult<()> + where + T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + { + // 转发到服务端 + if let Some(client) = &mut self.center_client { + return client.deal_new_stream(inbound).await; + } + if let Some(option) = &mut self.option.proxy { + let proxy_server = ProxyServer::new( + option.flag, + option.username.clone(), + option.password.clone(), + option.udp_bind.clone(), + None, + ); + tokio::spawn(async move { + // tcp的连接被移动到该协程中,我们只要专注的处理该stream即可 + let _ = proxy_server.deal_proxy(inbound).await; + }); } Ok(()) @@ -201,6 +208,7 @@ impl WMCore { ( self.proxy_accept, self.proxy_client, + self.client_listener, self.center_listener, self.center_client, ) = option.bind().await?; @@ -249,22 +257,26 @@ impl WMCore { loop { tokio::select! { Some((inbound, addr)) = Self::tcp_listen_work(&self.center_listener) => { - log::trace!("代理收到客户端连接: {}->{}", addr, self.center_listener.as_ref().unwrap().local_addr()?); + log::trace!("中心代理收到客户端连接: {}->{}", addr, self.center_listener.as_ref().unwrap().local_addr()?); if let Some(a) = self.proxy_accept.clone() { let inbound = a.accept(inbound).await; // 获取的流跟正常内容一样读写, 在内部实现了自动加解密 match inbound { Ok(inbound) => { - let _ = self.deal_stream(inbound, addr, self.proxy_client.clone()).await; + let _ = self.deal_center_stream(inbound, addr, self.proxy_client.clone()).await; } Err(e) => { log::warn!("接收来自下级代理的连接失败, 原因为: {:?}", e); } } } else { - let _ = self.deal_stream(inbound, addr, self.proxy_client.clone()).await; + let _ = self.deal_center_stream(inbound, addr, self.proxy_client.clone()).await; }; } + Some((inbound, addr)) = Self::tcp_listen_work(&self.client_listener) => { + log::trace!("代理收到客户端连接: {}->{}", addr, self.center_listener.as_ref().unwrap().local_addr()?); + let _ = self.deal_client_stream(inbound, addr).await; + } Some((inbound, addr)) = Self::tcp_listen_work(&self.map_http_listener) => { log::trace!("内网穿透:Http收到客户端连接: {}->{}", addr, self.map_http_listener.as_ref().unwrap().local_addr()?); self.server_new_http(inbound, addr).await?; @@ -357,37 +369,8 @@ impl WMCore { Ok(()) } - async fn transfer_server( - domain: Option, - tls_client: Option>, - mut inbound: T, - server: String, - ) -> ProxyResult<()> - where - T: AsyncRead + AsyncWrite + Unpin, - { - if tls_client.is_some() { - let connector = TlsConnector::from(tls_client.unwrap()); - let stream = HealthCheck::connect(&server).await?; - // 这里的域名只为认证设置 - let domain = - rustls::ServerName::try_from(&*domain.unwrap_or("soft.wm-proxy.com".to_string())) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?; - - if let Ok(mut outbound) = connector.connect(domain, stream).await { - // connect 之后的流跟正常内容一样读写, 在内部实现了自动加解密 - let _ = tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await?; - } else { - // TODO 返回对应协议的错误 - } - } else { - if let Ok(mut outbound) = HealthCheck::connect(&server).await { - let _ = tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await?; - } else { - // TODO 返回对应协议的错误 - } - } - Ok(()) + pub fn clear_close_servers(&mut self) { + self.center_servers.retain(|s| !s.is_close()); } pub async fn server_new_http( @@ -395,6 +378,7 @@ impl WMCore { stream: TcpStream, addr: SocketAddr, ) -> ProxyResult<()> { + self.clear_close_servers(); for server in &mut self.center_servers { if !server.is_close() { return server.server_new_http(stream, addr).await; @@ -410,6 +394,7 @@ impl WMCore { addr: SocketAddr, accept: TlsAcceptor, ) -> ProxyResult<()> { + self.clear_close_servers(); for server in &mut self.center_servers { if !server.is_close() { return server.server_new_https(stream, addr, accept).await; @@ -424,6 +409,7 @@ impl WMCore { stream: TcpStream, _addr: SocketAddr, ) -> ProxyResult<()> { + self.clear_close_servers(); for server in &mut self.center_servers { if !server.is_close() { return server.server_new_tcp(stream).await; @@ -438,6 +424,7 @@ impl WMCore { stream: TcpStream, _addr: SocketAddr, ) -> ProxyResult<()> { + self.clear_close_servers(); for server in &mut self.center_servers { if !server.is_close() { return server.server_new_prxoy(stream).await; diff --git a/tests/mapping.rs b/tests/mapping.rs index 1aec42f..ec9f28b 100644 --- a/tests/mapping.rs +++ b/tests/mapping.rs @@ -119,13 +119,11 @@ mod tests { let local_server_addr = run_server().await.unwrap(); let addr = "127.0.0.1:0".parse().unwrap(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .center_addr(addr) .map_http_bind(Some(addr)) .map_https_bind(Some(addr)) .map_tcp_bind(Some(addr)) .map_proxy_bind(Some(addr)) - .center(true) - .mode("server".to_string()) .into_value() .unwrap(); @@ -161,10 +159,8 @@ mod tests { mapping_proxy.local_addr = Some(local_server_addr); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .server(Some(format!("{}", server_addr))) - .center(true) - .mode("client".to_string()) .mapping(mapping) .mapping(mapping_tcp) .mapping(mapping_proxy) diff --git a/tests/proxy.rs b/tests/proxy.rs index 621d7ac..5ec3907 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -13,7 +13,7 @@ mod tests { static HTTP_URL: &str = "http://www.baidu.com"; static HTTPS_URL: &str = "https://www.baidu.com"; - async fn run_proxy( + async fn run_server_proxy( proxy: ProxyConfig, ) -> ProxyResult<(SocketAddr, Sender<()>)> { let option = ConfigOption::new_by_proxy(proxy); @@ -27,6 +27,20 @@ mod tests { Ok((addr, sender_close)) } + async fn run_proxy( + proxy: ProxyConfig, + ) -> ProxyResult<(SocketAddr, Sender<()>)> { + let option = ConfigOption::new_by_proxy(proxy); + let (sender_close, receiver_close) = channel::<()>(1); + let mut proxy = WMCore::new(option); + proxy.ready_serve().await.unwrap(); + let addr = proxy.client_listener.as_ref().unwrap().local_addr()?; + tokio::spawn(async move { + let _ = proxy.run_serve(receiver_close, None).await; + }); + Ok((addr, sender_close)) + } + async fn test_proxy( addr: SocketAddr, url: &str, @@ -92,7 +106,7 @@ mod tests { async fn test_no_auth() { let addr = "127.0.0.1:0".parse().unwrap(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .into_value() .unwrap(); @@ -109,7 +123,7 @@ mod tests { let username = "wmproxy".to_string(); let password = "wmproxy".to_string(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .username(Some(username.clone())) .password(Some(password.clone())) .into_value() @@ -137,23 +151,20 @@ mod tests { let password = "wmproxy".to_string(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .center_addr(addr) .username(Some(username.clone())) .password(Some(password.clone())) - .center(true) - .mode("server".to_string()) .into_value() .unwrap(); - let (server_addr, _sender) = run_proxy(proxy) + let (server_addr, _sender) = run_server_proxy(proxy) .await .unwrap(); let proxy = ProxyConfig::builder() - .bind_addr(addr) + .bind(addr) .username(Some(username.clone())) .password(Some(password.clone())) - .center(true) .server(Some(format!("{}", server_addr))) .into_value() .unwrap();