diff --git a/src/utils.rs b/src/utils.rs index 7bbb0c7..65c8b7f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,4 @@ +use std::fmt; use tokio::io::{Error, ErrorKind, Result}; pub type Aes128CfbEnc = cfb_mode::Encryptor; @@ -50,23 +51,70 @@ macro_rules! md5 { } } -// TODO: find port -pub(crate) fn extract_host(buf: &[u8]) -> Result<&[u8]> { +pub(crate) struct Addr<'a> { + host: Option<&'a [u8]>, + port: Option, +} + +impl<'a> fmt::Debug for Addr<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut addr = vec![]; + + if let Some(host) = self.host { + addr.extend_from_slice(host); + } + + if let Some(port) = self.port { + addr.extend_from_slice(b":"); + addr.extend_from_slice(&port.to_string().as_bytes()); + } + + write!(f, "{:?}", addr) + } +} + +impl<'a> Addr<'a> { + pub(crate) fn host(&self) -> &'a [u8] { + self.host.unwrap_or_default() + } + + pub(crate) fn port(&self) -> u16 { + self.port.unwrap_or(80) + } +} + +pub(crate) fn extract_addr<'a>(buf: &'a [u8]) -> Result> { let header = &[72, 111, 115, 116, 58, 32]; // "Host: " - let start = buf + let mut addr = Addr { + host: None, + port: None, + }; + + let mut start = buf .windows(header.len()) .position(|w| w == header) .map(|x| x + header.len()) - .ok_or(Error::new(ErrorKind::Other, "could not extract the host"))?; + .ok_or(Error::new(ErrorKind::Other, "could not extract address"))?; - let offset = buf + let offset = buf[start..] .iter() - .skip(start) .position(|&x| x == b'\r') - .ok_or(Error::new(ErrorKind::Other, "could not extract the host"))?; + .ok_or(Error::new(ErrorKind::Other, "could not extract address"))?; + + let port_offset = buf[start..start + offset].iter().position(|&x| x == b':'); + if let Some(port_offset) = port_offset { + addr.host = Some(&buf[start..start + port_offset]); - Ok(&buf[start..start + offset]) + let end = start + offset; + start += port_offset + 1; // skip colon + let port = String::from_utf8_lossy(&buf[start..end]); + addr.port = u16::from_str_radix(&port, 10).ok(); + } else { + addr.host = Some(&buf[start..start + offset]); + } + + Ok(addr) } #[cfg(test)] @@ -74,9 +122,16 @@ mod tests { use super::*; #[test] - fn test_extract_host() { + fn test_extract_addr() { let buf = b"GET http://google.com/ HTTP/1.1\r\nHost: google.com\r\nUser-Agent: curl/7.85.0"; - let host = extract_host(buf).unwrap(); - assert_eq!(host, b"google.com"); + let addr = extract_addr(buf).unwrap(); + assert_eq!(addr.host(), b"google.com"); + assert_eq!(addr.port(), 80); + + let buf = + b"GET http://google.com/ HTTP/1.1\r\nHost: google.com:443\r\nUser-Agent: curl/7.85.0"; + let addr = extract_addr(buf).unwrap(); + assert_eq!(addr.host(), b"google.com"); + assert_eq!(addr.port(), 443); } } diff --git a/src/vmess.rs b/src/vmess.rs index dd6c470..23647ec 100644 --- a/src/vmess.rs +++ b/src/vmess.rs @@ -1,4 +1,4 @@ -use crate::utils::{extract_host, Aes128CfbDec, Aes128CfbEnc}; +use crate::utils::{extract_addr, Addr, Aes128CfbDec, Aes128CfbEnc}; use std::marker::Unpin; use std::time::{SystemTime, UNIX_EPOCH}; @@ -175,7 +175,7 @@ pub struct VmessWriter { } impl VmessWriter { - async fn handshake(&mut self, domain: &[u8]) -> Result<()> { + async fn handshake<'a>(&mut self, addr: Addr<'a>) -> Result<()> { // https://xtls.github.io/en/development/protocols/vmess.html#authentication-information // // +----------------------------+ @@ -229,13 +229,13 @@ impl VmessWriter { ]); // TODO: extract port from request. for now we use 80 for all requests - cmd.extend_from_slice(&(80u16).to_be_bytes()); // Port + cmd.extend_from_slice(&addr.port().to_be_bytes()); // Port // TODO: support ipv4/ipv6. for now we just support domain name cmd.extend_from_slice(&[0x02]); // Address Type: Domain name - let mut address = vec![domain.len() as _]; - address.extend_from_slice(domain); + let mut address = vec![addr.host().len() as _]; + address.extend_from_slice(addr.host()); cmd.extend_from_slice(&address); // P bytes random value -> assume p = 0, so we don't push data for it @@ -298,10 +298,10 @@ impl VmessWriter { pub async fn write(&mut self, buf: &[u8]) -> Result<()> { if !self.handshaked { - let domain = extract_host(buf)?; - log::info!("accepted {}", String::from_utf8_lossy(domain)); + let addr = extract_addr(buf)?; + log::info!("accepted {:?}", addr); - self.handshake(domain).await?; + self.handshake(addr).await?; self.handshaked = true; }