Skip to content

Commit

Permalink
升级rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
tickbh committed Jan 26, 2024
1 parent a7aa354 commit a961231
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 101 deletions.
26 changes: 13 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "wmproxy"
version = "0.2.5"
version = "0.2.6"
edition = "2021"
authors = ["tickbh <[email protected]>"]
description = " http https proxy by rust"
Expand All @@ -16,15 +16,15 @@ log = "0.4.19"
bitflags = "2.4"

tokio-util = "0.7"
rustls = { version = "0.21.7", default-features = false }
webpki-roots = "0.25"
rustls-pemfile = "1.0.3"
rustls = { version = "0.22.2", default-features = false }
webpki-roots = "0.26.0"
rustls-pemfile = "2.0.0"
webpki = { version = "0.22", features = ["alloc", "std"] }
tokio-rustls = "0.24"
tokio-rustls = "0.25.0"
futures-core = { version = "0.3", default-features = false }
futures = "0.3.28"

env_logger = "0.10.0"
env_logger = "0.11.0"
serde = { version = "1.0", features = ["derive"] }
serde_with = "3.4.0"
serde_yaml = "0.9"
Expand All @@ -39,7 +39,7 @@ log4rs = "1.2.0"
chrono = "0.4.31"

async-trait = "0.1.74"
rbtree = "0.1.7"
rbtree = "0.2.0"

regex = "1.10.2"

Expand All @@ -53,15 +53,15 @@ bpaf = { version = "0.9.8", features = [
"batteries",
"autocomplete",
] }
# webparse = { version = "0.2.4" }
# wenmeng = { version = "0.2.5" }
webparse = { version = "0.2.6" }
wenmeng = { version = "0.2.6" }
# wenmeng={git="https://github.com/tickbh/wenmeng.git"}
[features]
bright-color = ["bpaf/bright-color"]
dull-color = ["bpaf/dull-color"]

[dependencies.webparse]
path = "../webparse"
# [dependencies.webparse]
# path = "../webparse"

[dependencies.wenmeng]
path = "../wenmeng"
# [dependencies.wenmeng]
# path = "../wenmeng"
90 changes: 36 additions & 54 deletions src/option.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use std::{

use bpaf::*;
use log::LevelFilter;
use rustls::{Certificate, ClientConfig, PrivateKey};
use rustls::{pki_types::{CertificateDer, PrivateKeyDer}, ClientConfig};

use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use tokio::net::TcpListener;
use tokio_rustls::{rustls, TlsAcceptor};

use crate::{
reverse::{HttpConfig, StreamConfig, UpstreamConfig}, CenterClient, Flag, Helper, MappingConfig, OneHealth, ProxyError, ProxyResult, WrapAddr
reverse::{HttpConfig, StreamConfig, UpstreamConfig},
CenterClient, Flag, Helper, MappingConfig, OneHealth, ProxyError, ProxyResult, WrapAddr,
};

pub struct Builder {
Expand Down Expand Up @@ -79,7 +80,6 @@ impl Builder {
Ok(proxy)
})
}


pub fn server(self, addr: Option<String>) -> Builder {
self.and_then(|mut proxy| {
Expand Down Expand Up @@ -202,32 +202,17 @@ fn default_bind_addr() -> SocketAddr {
#[derive(Debug, Clone, Serialize, Deserialize, Bpaf)]
pub struct ProxyConfig {
/// 代理id
#[bpaf(
fallback(0),
display_fallback,
short('s'),
long
)]
#[bpaf(fallback(0), display_fallback, short('s'), long)]
#[serde(default)]
pub(crate) server_id: u32,

/// 代理绑定端口地址
#[bpaf(
// fallback(Some(WrapAddr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8090)))),
// display_fallback,
short('b'),
long
)]
#[bpaf(short('b'), long)]
#[serde_as(as = "Option<DisplayFromStr>")]
pub(crate) bind: Option<WrapAddr>,

/// 代理绑定端口地址
#[bpaf(
// fallback(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8090)),
// display_fallback,
short('c'),
long
)]

/// 中心代理绑定端口地址
#[bpaf(short('c'), long)]
#[serde_as(as = "Option<DisplayFromStr>")]
pub(crate) center_addr: Option<WrapAddr>,

Expand All @@ -236,10 +221,6 @@ pub struct ProxyConfig {
#[serde_as(as = "DisplayFromStr")]
#[serde(default)]
pub(crate) flag: Flag,
// /// 启动程序类型
// #[bpaf(fallback("client".to_string()))]
// #[serde(default)]
// pub(crate) mode: String,

/// 连接代理服务端地址
#[bpaf(short('S'), long("server"))]
Expand Down Expand Up @@ -360,12 +341,12 @@ impl ProxyConfig {
Builder::new()
}

fn load_certs(path: &Option<String>) -> io::Result<Vec<Certificate>> {
fn load_certs(path: &Option<String>) -> io::Result<Vec<CertificateDer<'static>>> {
if let Some(path) = path {
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)?;
Ok(certs.into_iter().map(Certificate).collect())
let certs = rustls_pemfile::certs(&mut reader);
Ok(certs.into_iter().collect::<Result<Vec<_>, _>>()?)
} else {
let cert = br"-----BEGIN CERTIFICATE-----
MIIF+zCCBOOgAwIBAgIQCkkcvmucB5JXt9JAehuNqTANBgkqhkiG9w0BAQsFADBu
Expand Down Expand Up @@ -432,16 +413,16 @@ n2hcLrfZSbynEC/pSw/ET7H5nWwckjmAJ1l9fcnbqkU/pf6uMQmnfl0JQjJNSg==

let cursor = io::Cursor::new(cert);
let mut buf = BufReader::new(cursor);
let certs = rustls_pemfile::certs(&mut buf)?;
Ok(certs.into_iter().map(Certificate).collect())
let certs = rustls_pemfile::certs(&mut buf);
Ok(certs.into_iter().collect::<Result<Vec<_>, _>>()?)
}
}

fn load_keys(path: &Option<String>) -> io::Result<PrivateKey> {
fn load_keys(path: &Option<String>) -> io::Result<PrivateKeyDer<'static>> {
let mut keys = if let Some(path) = path {
let file = File::open(&path)?;
let mut reader = BufReader::new(file);
rustls_pemfile::rsa_private_keys(&mut reader)?
rustls_pemfile::rsa_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?
} else {
let key = br"-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAw7gdhMEwp5al49V4b3DkwPWUa/Aiaxo5dk8+JWETaIfU8L9w
Expand Down Expand Up @@ -473,15 +454,15 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC
";
let cursor = io::Cursor::new(key);
let mut buf = BufReader::new(cursor);
rustls_pemfile::rsa_private_keys(&mut buf)?
rustls_pemfile::rsa_private_keys(&mut buf).collect::<Result<Vec<_>, _>>()?
};

match keys.len() {
0 => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("No RSA private key found"),
)),
1 => Ok(PrivateKey(keys.remove(0))),
1 => Ok(PrivateKeyDer::from(keys.remove(0))),
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("More than one RSA private key found"),
Expand All @@ -498,7 +479,6 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC
let key = Self::load_keys(&self.map_key)?;

let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
Expand All @@ -518,18 +498,22 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC
let certs = Self::load_certs(&self.cert)?;
let key = Self::load_keys(&self.key)?;

let config = rustls::ServerConfig::builder().with_safe_defaults();
let config = rustls::ServerConfig::builder();
// 开始双向认证,需要客户端提供证书信息
let config = if self.two_way_tls {
let mut client_auth_roots = rustls::RootCertStore::empty();
for root in &certs {
client_auth_roots.add(&root).unwrap();
for root in certs.clone().into_iter() {
client_auth_roots.add(root).unwrap();
}
let client_auth =
rustls::server::WebPkiClientVerifier::builder(client_auth_roots.into())
.build()
.map_err(|_| ProxyError::Extension("add cert error"))?;

let client_auth = rustls::server::AllowAnyAuthenticatedClient::new(client_auth_roots);
// let client_auth = rustls::server::AllowAnyAuthenticatedClient::new(client_auth_roots);

config
.with_client_cert_verifier(client_auth.boxed())
.with_client_cert_verifier(client_auth)
.with_single_cert(certs, key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
} else {
Expand All @@ -551,19 +535,11 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC
let certs = Self::load_certs(&self.cert)?;
let mut root_cert_store = rustls::RootCertStore::empty();
// 信任通用的签名商
root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
for cert in &certs {
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
for cert in certs.clone().into_iter() {
let _ = root_cert_store.add(cert);
}
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store);
let config = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);

if self.two_way_tls {
let key = Self::load_keys(&self.key)?;
Expand Down Expand Up @@ -635,7 +611,13 @@ cR+nZ6DRmzKISbcN9/m8I7xNWwU2cglrYa4NCHguQSrTefhRoZAfl8BEOW1rJVGC
} else {
None
};
Ok((proxy_accept, client, client_listener, center_listener, center_client))
Ok((
proxy_accept,
client,
client_listener,
center_listener,
center_client,
))
}

pub async fn bind_map(
Expand Down
21 changes: 11 additions & 10 deletions src/reverse/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use std::{
use crate::{data::LimitReqData, Helper, ProxyResult};
use async_trait::async_trait;
use rustls::{
crypto::ring::sign::any_supported_type,
pki_types::{CertificateDer, PrivateKeyDer},
server::ResolvesServerCertUsingSni,
sign::{self, CertifiedKey},
Certificate, PrivateKey,
sign::CertifiedKey,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
Expand Down Expand Up @@ -133,13 +134,13 @@ impl HttpConfig {
}
}

fn load_certs(path: &Option<String>) -> io::Result<Vec<Certificate>> {
fn load_certs(path: &Option<String>) -> io::Result<Vec<CertificateDer<'static>>> {
if let Some(path) = path {
match File::open(&path) {
Ok(file) => {
let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader)?;
Ok(certs.into_iter().map(Certificate).collect())
let certs = rustls_pemfile::certs(&mut reader);
Ok(certs.into_iter().collect::<Result<Vec<_>, _>>()?)
}
Err(e) => {
log::warn!("加载公钥{}出错,错误内容:{:?}", path, e);
Expand All @@ -151,12 +152,12 @@ impl HttpConfig {
}
}

fn load_keys(path: &Option<String>) -> io::Result<PrivateKey> {
fn load_keys(path: &Option<String>) -> io::Result<PrivateKeyDer<'static>> {
let mut keys = if let Some(path) = path {
match File::open(&path) {
Ok(file) => {
let mut reader = BufReader::new(file);
rustls_pemfile::rsa_private_keys(&mut reader)?
rustls_pemfile::rsa_private_keys(&mut reader).collect::<Result<Vec<_>, _>>()?
}
Err(e) => {
log::warn!("加载私钥{}出错,错误内容:{:?}", path, e);
Expand All @@ -172,7 +173,7 @@ impl HttpConfig {
io::ErrorKind::InvalidInput,
format!("No RSA private key found"),
)),
1 => Ok(PrivateKey(keys.remove(0))),
1 => Ok(PrivateKeyDer::from(keys.remove(0))),
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("More than one RSA private key found"),
Expand All @@ -186,12 +187,12 @@ impl HttpConfig {
let mut listeners = vec![];
let mut tlss = vec![];
let mut bind_port = HashSet::new();
let config = rustls::ServerConfig::builder().with_safe_defaults();
let config = rustls::ServerConfig::builder();
let mut resolve = ResolvesServerCertUsingSni::new();
for value in &self.server.clone() {
let mut is_ssl = false;
if value.cert.is_some() && value.key.is_some() {
let key = sign::any_supported_type(&Self::load_keys(&value.key)?)
let key = any_supported_type(&Self::load_keys(&value.key)?)
.map_err(|_| ProtError::Extension("unvaild key"))?;
let ck = CertifiedKey::new(Self::load_certs(&value.cert)?, key);
resolve.add(&value.up_name, ck).map_err(|e| {
Expand Down
4 changes: 2 additions & 2 deletions src/streams/center_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl CenterClient {
let stream = HealthCheck::connect(&server_addr).await?;
// 这里的域名只为认证设置
let domain =
rustls::ServerName::try_from(&*domain.unwrap_or("soft.wm-proxy.com".to_string()))
rustls::pki_types::ServerName::try_from(domain.unwrap_or("soft.wm-proxy.com".to_string()))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;

let outbound = connector.connect(domain, stream).await?;
Expand Down Expand Up @@ -364,7 +364,7 @@ impl CenterClient {

fn calc_next_id(&mut self) -> u64 {
let id = self.next_id;
self.next_id += 2;
self.next_id = self.next_id.wrapping_add(2);
Helper::calc_sock_map(self.option.server_id, id)
}

Expand Down
2 changes: 1 addition & 1 deletion src/streams/center_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl CenterServer {

pub fn calc_next_id(&mut self) -> u64 {
let id = self.next_id;
self.next_id += 2;
self.next_id = self.next_id.wrapping_add(2);
Helper::calc_sock_map(self.option.server_id, id)
}

Expand Down
15 changes: 5 additions & 10 deletions src/streams/center_trans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,13 @@

use std::{io, sync::Arc};

use tokio::{
io::{AsyncRead, AsyncWrite},
};

use tokio_rustls::{TlsConnector};
use tokio::io::{AsyncRead, AsyncWrite};

use tokio_rustls::TlsConnector;

use wenmeng::MaybeHttpsStream;

use crate::{
HealthCheck, ProxyResult,
};
use crate::{HealthCheck, ProxyResult};

/// 中心服务端
/// 接受中心客户端的连接,并且将信息处理或者转发
Expand Down Expand Up @@ -54,8 +49,8 @@ impl CenterTrans {
let connector = TlsConnector::from(self.tls_client.clone().unwrap());
let stream = HealthCheck::connect(&self.server).await?;
// 这里的域名只为认证设置
let domain = rustls::ServerName::try_from(
&*self
let domain = rustls::pki_types::ServerName::try_from(
self
.domain
.clone()
.unwrap_or("soft.wm-proxy.com".to_string()),
Expand Down
Loading

0 comments on commit a961231

Please sign in to comment.