From e3f5a042d1d5d00cfdac0b3f8f5cd15f33f25253 Mon Sep 17 00:00:00 2001 From: V Date: Wed, 29 Jan 2025 19:32:49 +0800 Subject: [PATCH] style: make poll_read_exact a general trait (#689) --- clash_lib/src/common/io.rs | 55 ++++++++++++++ .../proxy/shadowsocks/shadow_tls/stream.rs | 71 +++---------------- .../src/proxy/vmess/vmess_impl/stream.rs | 53 ++------------ 3 files changed, 73 insertions(+), 106 deletions(-) diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index cdc45d9ce..123de6794 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -2,11 +2,13 @@ use std::future::Future; use std::{ io, + mem::MaybeUninit, pin::Pin, task::{Context, Poll}, time::Duration, }; +use bytes::BytesMut; use futures::ready; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -352,3 +354,56 @@ where } .await } + +pub trait ReadExactBase { + /// inner stream to be polled + type I: AsyncRead + Unpin; + /// prepare the inner stream, read buffer and read position + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize); +} + +pub trait ReadExt: ReadExactBase { + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll>; +} + +impl ReadExt for T { + fn poll_read_exact( + &mut self, + cx: &mut std::task::Context, + size: usize, + ) -> Poll> { + let (raw, read_buf, read_pos) = self.decompose(); + read_buf.reserve(size); + // # safety: read_buf has reserved `size` + unsafe { read_buf.set_len(size) } + loop { + if *read_pos < size { + // # safety: read_pos]) + }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?; + assert_eq!(ptr, buf.filled().as_ptr()); + if buf.filled().is_empty() { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected eof", + ))); + } + *read_pos += buf.filled().len(); + } else { + assert!(*read_pos == size); + *read_pos = 0; + return Poll::Ready(Ok(())); + } + } + } +} diff --git a/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs b/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs index 9f441339c..4a0373eda 100644 --- a/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs +++ b/clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs @@ -1,5 +1,4 @@ use std::{ - mem::MaybeUninit, pin::Pin, ptr::{copy, copy_nonoverlapping}, task::{ready, Poll}, @@ -7,7 +6,9 @@ use std::{ use byteorder::{BigEndian, WriteBytesExt}; use bytes::{BufMut, BytesMut}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::common::io::{ReadExactBase, ReadExt}; use super::utils::{prelude::*, *}; @@ -26,60 +27,6 @@ pub enum WriteState { FlushingData(usize, usize, usize), } -pub trait AsyncReadUnpin: AsyncRead + Unpin {} - -impl AsyncReadUnpin for T {} - -pub trait ReadExtBase { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize); -} - -pub trait ReadExt { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll>; -} - -impl ReadExt for T { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll> { - let (raw, read_buf, read_pos) = self.prepare(); - read_buf.reserve(size); - // # safety: read_buf has reserved `size` - unsafe { read_buf.set_len(size) } - loop { - if *read_pos < size { - // # safety: read_pos]) - }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?; - assert_eq!(ptr, buf.filled().as_ptr()); - if buf.filled().is_empty() { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "unexpected eof", - ))); - } - *read_pos += buf.filled().len(); - } else { - assert!(*read_pos == size); - *read_pos = 0; - return Poll::Ready(Ok(())); - } - } - } -} - #[derive(Clone, Debug)] pub struct Certs { pub(crate) server_random: [u8; TLS_RANDOM_SIZE], @@ -139,8 +86,10 @@ impl ProxyTlsStream { } } -impl ReadExtBase for ProxyTlsStream { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) { +impl ReadExactBase for ProxyTlsStream { + type I = S; + + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { (&mut self.raw, &mut self.read_buf, &mut self.read_pos) } } @@ -334,8 +283,10 @@ impl VerifiedStream { } } -impl ReadExtBase for VerifiedStream { - fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) { +impl ReadExactBase for VerifiedStream { + type I = S; + + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { (&mut self.raw, &mut self.read_buf, &mut self.read_pos) } } diff --git a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs index 87b81bc61..16b592dab 100644 --- a/clash_lib/src/proxy/vmess/vmess_impl/stream.rs +++ b/clash_lib/src/proxy/vmess/vmess_impl/stream.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, mem::MaybeUninit, pin::Pin, task::Poll, time::SystemTime}; +use std::{fmt::Debug, pin::Pin, task::Poll, time::SystemTime}; use aes_gcm::Aes128Gcm; use bytes::{BufMut, BytesMut}; @@ -6,7 +6,7 @@ use chacha20poly1305::ChaCha20Poly1305; use futures::ready; use md5::Md5; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use crate::{ common::{ @@ -78,52 +78,13 @@ enum WriteState { FlushingData(usize, (usize, usize)), } -pub trait ReadExt { - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll>; - #[allow(unused)] - fn get_data(&self) -> &[u8]; -} +use crate::common::io::{ReadExactBase, ReadExt}; -impl ReadExt for VmessStream { - // Read exactly `size` bytes into `read_buf`, starting from position 0. - fn poll_read_exact( - &mut self, - cx: &mut std::task::Context, - size: usize, - ) -> Poll> { - self.read_buf.reserve(size); - unsafe { self.read_buf.set_len(size) } - loop { - if self.read_pos < size { - let dst = unsafe { - &mut *((&mut self.read_buf[self.read_pos..size]) as *mut _ - as *mut [MaybeUninit]) - }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?; - assert_eq!(ptr, buf.filled().as_ptr()); - if buf.filled().is_empty() { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "unexpected eof", - ))); - } - self.read_pos += buf.filled().len(); - } else { - assert!(self.read_pos == size); - self.read_pos = 0; - return Poll::Ready(Ok(())); - } - } - } +impl ReadExactBase for VmessStream { + type I = S; - fn get_data(&self) -> &[u8] { - self.read_buf.as_ref() + fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) { + (&mut self.stream, &mut self.read_buf, &mut self.read_pos) } }