Skip to content

Commit

Permalink
style: make poll_read_exact a general trait (#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
VendettaReborn authored Jan 29, 2025
1 parent 58f1118 commit e3f5a04
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 106 deletions.
55 changes: 55 additions & 0 deletions clash_lib/src/common/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
use std::future::Future;
use std::{
io,
mem::MaybeUninit,
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use bytes::BytesMut;
use futures::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

Expand Down Expand Up @@ -352,3 +354,56 @@ where
}
.await
}

pub trait ReadExactBase {
/// inner stream to be polled
type I: AsyncRead + Unpin;
/// prepare the inner stream, read buffer and read position
fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize);
}

pub trait ReadExt: ReadExactBase {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
}

impl<T: ReadExactBase> ReadExt for T {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
let (raw, read_buf, read_pos) = self.decompose();
read_buf.reserve(size);
// # safety: read_buf has reserved `size`
unsafe { read_buf.set_len(size) }
loop {
if *read_pos < size {
// # safety: read_pos<size==read_buf.len(), and
// read_buf[0..read_pos] is initialized
let dst = unsafe {
&mut *((&mut read_buf[*read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
*read_pos += buf.filled().len();
} else {
assert!(*read_pos == size);
*read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
}
71 changes: 11 additions & 60 deletions clash_lib/src/proxy/shadowsocks/shadow_tls/stream.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::{
mem::MaybeUninit,
pin::Pin,
ptr::{copy, copy_nonoverlapping},
task::{ready, Poll},
};

use byteorder::{BigEndian, WriteBytesExt};
use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::io::{ReadExactBase, ReadExt};

use super::utils::{prelude::*, *};

Expand All @@ -26,60 +27,6 @@ pub enum WriteState {
FlushingData(usize, usize, usize),
}

pub trait AsyncReadUnpin: AsyncRead + Unpin {}

impl<T: AsyncRead + Unpin> AsyncReadUnpin for T {}

pub trait ReadExtBase {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize);
}

pub trait ReadExt {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
}

impl<T: ReadExtBase> ReadExt for T {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
let (raw, read_buf, read_pos) = self.prepare();
read_buf.reserve(size);
// # safety: read_buf has reserved `size`
unsafe { read_buf.set_len(size) }
loop {
if *read_pos < size {
// # safety: read_pos<size==read_buf.len(), and
// read_buf[0..read_pos] is initialized
let dst = unsafe {
&mut *((&mut read_buf[*read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut *raw).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
*read_pos += buf.filled().len();
} else {
assert!(*read_pos == size);
*read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
}

#[derive(Clone, Debug)]
pub struct Certs {
pub(crate) server_random: [u8; TLS_RANDOM_SIZE],
Expand Down Expand Up @@ -139,8 +86,10 @@ impl<S> ProxyTlsStream<S> {
}
}

impl<S: AsyncReadUnpin> ReadExtBase for ProxyTlsStream<S> {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) {
impl<S: AsyncRead + Unpin> ReadExactBase for ProxyTlsStream<S> {
type I = S;

fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.raw, &mut self.read_buf, &mut self.read_pos)
}
}
Expand Down Expand Up @@ -334,8 +283,10 @@ impl<S> VerifiedStream<S> {
}
}

impl<S: AsyncReadUnpin> ReadExtBase for VerifiedStream<S> {
fn prepare(&mut self) -> (&mut dyn AsyncReadUnpin, &mut BytesMut, &mut usize) {
impl<S: AsyncRead + Unpin> ReadExactBase for VerifiedStream<S> {
type I = S;

fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.raw, &mut self.read_buf, &mut self.read_pos)
}
}
Expand Down
53 changes: 7 additions & 46 deletions clash_lib/src/proxy/vmess/vmess_impl/stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{fmt::Debug, mem::MaybeUninit, pin::Pin, task::Poll, time::SystemTime};
use std::{fmt::Debug, pin::Pin, task::Poll, time::SystemTime};

use aes_gcm::Aes128Gcm;
use bytes::{BufMut, BytesMut};
use chacha20poly1305::ChaCha20Poly1305;
use futures::ready;

use md5::Md5;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};

use crate::{
common::{
Expand Down Expand Up @@ -78,52 +78,13 @@ enum WriteState {
FlushingData(usize, (usize, usize)),
}

pub trait ReadExt {
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>>;
#[allow(unused)]
fn get_data(&self) -> &[u8];
}
use crate::common::io::{ReadExactBase, ReadExt};

impl<S: AsyncRead + Unpin> ReadExt for VmessStream<S> {
// Read exactly `size` bytes into `read_buf`, starting from position 0.
fn poll_read_exact(
&mut self,
cx: &mut std::task::Context,
size: usize,
) -> Poll<std::io::Result<()>> {
self.read_buf.reserve(size);
unsafe { self.read_buf.set_len(size) }
loop {
if self.read_pos < size {
let dst = unsafe {
&mut *((&mut self.read_buf[self.read_pos..size]) as *mut _
as *mut [MaybeUninit<u8>])
};
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
ready!(Pin::new(&mut self.stream).poll_read(cx, &mut buf))?;
assert_eq!(ptr, buf.filled().as_ptr());
if buf.filled().is_empty() {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected eof",
)));
}
self.read_pos += buf.filled().len();
} else {
assert!(self.read_pos == size);
self.read_pos = 0;
return Poll::Ready(Ok(()));
}
}
}
impl<S: AsyncRead + Unpin> ReadExactBase for VmessStream<S> {
type I = S;

fn get_data(&self) -> &[u8] {
self.read_buf.as_ref()
fn decompose(&mut self) -> (&mut Self::I, &mut BytesMut, &mut usize) {
(&mut self.stream, &mut self.read_buf, &mut self.read_pos)
}
}

Expand Down

0 comments on commit e3f5a04

Please sign in to comment.