From 34db8de361fd4a8d5f74e9c4ae737d60aadb45ab Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Sat, 21 Oct 2023 08:28:12 -0400 Subject: [PATCH 1/4] feat(http1) Add support for writing Trailer Fields Closes #2719 --- src/proto/h1/conn.rs | 38 ++++++++++++ src/proto/h1/dispatch.rs | 42 +++++++------ src/proto/h1/encode.rs | 124 ++++++++++++++++++++++++++++++++++++-- src/proto/h1/role.rs | 68 +++++++++++++++++++-- tests/client.rs | 50 ++++++++++++++- tests/server.rs | 103 ++++++++++++++++++++++++++++++- tests/support/mod.rs | 2 + tests/support/trailers.rs | 76 +++++++++++++++++++++++ 8 files changed, 473 insertions(+), 30 deletions(-) create mode 100644 tests/support/trailers.rs diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 43fba3b793..c2719ac555 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -75,6 +75,7 @@ where // We assume a modern world where the remote speaks HTTP/1.1. // If they tell us otherwise, we'll downgrade in `read_head`. version: Version::HTTP_11, + allow_trailer_fields: false, }, _marker: PhantomData, } @@ -264,6 +265,16 @@ where self.state.reading = Reading::Body(Decoder::new(msg.decode)); } + if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) { + if te_value.eq_ignore_ascii_case("trailers") { + self.state.allow_trailer_fields = true; + } else { + self.state.allow_trailer_fields = false; + } + } else { + self.state.allow_trailer_fields = false; + } + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) } @@ -640,6 +651,31 @@ where self.state.writing = state; } + pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) { + if T::is_server() && self.state.allow_trailer_fields == false { + debug!("trailers not allowed to be sent"); + return; + } + debug_assert!(self.can_write_body() && self.can_buffer_body()); + + match self.state.writing { + Writing::Body(ref encoder) => { + if let Some(enc_buf) = + encoder.encode_trailers(trailers, self.state.title_case_headers) + { + self.io.buffer(enc_buf); + + self.state.writing = if encoder.is_last() || encoder.is_close_delimited() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + _ => unreachable!("write_trailers invalid state: {:?}", self.state.writing), + } + } + pub(crate) fn write_body_and_end(&mut self, chunk: B) { debug_assert!(self.can_write_body() && self.can_buffer_body()); // empty chunks should be discarded at Dispatcher level @@ -842,6 +878,8 @@ struct State { upgrade: Option, /// Either HTTP/1.0 or 1.1 connection version: Version, + /// Flag to track if trailer fields are allowed to be sent + allow_trailer_fields: bool, } #[derive(Debug)] diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index c29c15dcae..0871af12ef 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -351,27 +351,33 @@ where *clear_body = true; crate::Error::new_user_body(e) })?; - let chunk = if let Ok(data) = frame.into_data() { - data - } else { - trace!("discarding non-data frame"); - continue; - }; - let eos = body.is_end_stream(); - if eos { - *clear_body = true; - if chunk.remaining() == 0 { - trace!("discarding empty chunk"); - self.conn.end_body()?; + + if frame.is_data() { + let chunk = frame.into_data().unwrap_or_else(|_| unreachable!()); + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body()?; + } else { + self.conn.write_body_and_end(chunk); + } } else { - self.conn.write_body_and_end(chunk); + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); } + } else if frame.is_trailers() { + *clear_body = true; + self.conn.write_trailers( + frame.into_trailers().unwrap_or_else(|_| unreachable!()), + ); } else { - if chunk.remaining() == 0 { - trace!("discarding empty chunk"); - continue; - } - self.conn.write_body(chunk); + trace!("discarding unknown frame"); + continue; } } else { *clear_body = true; diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index c98c55d664..3a42a795a2 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -1,10 +1,19 @@ +use std::collections::HashMap; use std::fmt; use std::io::IoSlice; use bytes::buf::{Chain, Take}; -use bytes::Buf; +use bytes::{Buf, Bytes}; +use http::{ + header::{ + AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING, + }, + HeaderMap, HeaderName, HeaderValue, +}; use super::io::WriteBuf; +use super::role::{write_headers, write_headers_title_case}; type StaticBuf = &'static [u8]; @@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64); #[derive(Debug, PartialEq, Clone)] enum Kind { /// An Encoder for when Transfer-Encoding includes `chunked`. - Chunked, + Chunked(Option>), /// An Encoder for when Content-Length is set. /// /// Enforces that the body is not longer than the Content-Length header. @@ -45,6 +54,7 @@ enum BufKind { Limited(Take), Chunked(Chain, StaticBuf>), ChunkedEnd(StaticBuf), + Trailers(Chain, StaticBuf>), } impl Encoder { @@ -55,7 +65,7 @@ impl Encoder { } } pub(crate) fn chunked() -> Encoder { - Encoder::new(Kind::Chunked) + Encoder::new(Kind::Chunked(None)) } pub(crate) fn length(len: u64) -> Encoder { @@ -67,6 +77,16 @@ impl Encoder { Encoder::new(Kind::CloseDelimited) } + pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec) -> Encoder { + match self.kind { + Kind::Chunked(_) => Encoder { + kind: Kind::Chunked(Some(trailers)), + is_last: self.is_last, + }, + _ => self, + } + } + pub(crate) fn is_eof(&self) -> bool { matches!(self.kind, Kind::Length(0)) } @@ -89,10 +109,17 @@ impl Encoder { } } + pub(crate) fn is_chunked(&self) -> bool { + match self.kind { + Kind::Chunked(_) => true, + _ => false, + } + } + pub(crate) fn end(&self) -> Result>, NotEof> { match self.kind { Kind::Length(0) => Ok(None), - Kind::Chunked => Ok(Some(EncodedBuf { + Kind::Chunked(_) => Ok(Some(EncodedBuf { kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), })), #[cfg(feature = "server")] @@ -109,7 +136,7 @@ impl Encoder { debug_assert!(len > 0, "encode() called with empty buf"); let kind = match self.kind { - Kind::Chunked => { + Kind::Chunked(_) => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) @@ -136,6 +163,54 @@ impl Encoder { EncodedBuf { kind } } + pub(crate) fn encode_trailers( + &self, + mut trailers: HeaderMap, + title_case_headers: bool, + ) -> Option> { + match &self.kind { + Kind::Chunked(allowed_trailer_fields) => { + let allowed_trailer_fields_map = match allowed_trailer_fields { + Some(ref allowed_trailer_fields) => { + allowed_trailer_field_map(&allowed_trailer_fields) + } + None => return None, + }; + + let mut cur_name = None; + let mut allowed_trailers = HeaderMap::new(); + + for (opt_name, value) in trailers.drain() { + if let Some(n) = opt_name { + cur_name = Some(n); + } + let name = cur_name.as_ref().expect("current header name"); + + if allowed_trailer_fields_map.contains_key(name.as_str()) + && !invalid_trailer_field(name) + { + allowed_trailers.insert(name, value); + } + } + + let mut buf = Vec::new(); + if title_case_headers { + write_headers_title_case(&allowed_trailers, &mut buf); + } else { + write_headers(&allowed_trailers, &mut buf); + } + + Some(EncodedBuf { + kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")), + }) + } + _ => { + debug!("attempted to encode trailers for non-chunked response"); + None + } + } + } + pub(super) fn encode_and_end(&self, msg: B, dst: &mut WriteBuf>) -> bool where B: Buf, @@ -144,7 +219,7 @@ impl Encoder { debug_assert!(len > 0, "encode() called with empty buf"); match self.kind { - Kind::Chunked => { + Kind::Chunked(_) => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) @@ -181,6 +256,39 @@ impl Encoder { } } +fn invalid_trailer_field(name: &HeaderName) -> bool { + match name { + &AUTHORIZATION => true, + &CACHE_CONTROL => true, + &CONTENT_ENCODING => true, + &CONTENT_LENGTH => true, + &CONTENT_RANGE => true, + &CONTENT_TYPE => true, + &HOST => true, + &MAX_FORWARDS => true, + &SET_COOKIE => true, + &TRAILER => true, + &TRANSFER_ENCODING => true, + _ => false, + } +} + +fn allowed_trailer_field_map(allowed_trailer_fields: &Vec) -> HashMap { + let mut trailer_map = HashMap::new(); + + for header_value in allowed_trailer_fields { + if let Ok(header_str) = header_value.to_str() { + let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect(); + + for item in items { + trailer_map.entry(item.to_string()).or_insert(()); + } + } + } + + trailer_map +} + impl Buf for EncodedBuf where B: Buf, @@ -192,6 +300,7 @@ where BufKind::Limited(ref b) => b.remaining(), BufKind::Chunked(ref b) => b.remaining(), BufKind::ChunkedEnd(ref b) => b.remaining(), + BufKind::Trailers(ref b) => b.remaining(), } } @@ -202,6 +311,7 @@ where BufKind::Limited(ref b) => b.chunk(), BufKind::Chunked(ref b) => b.chunk(), BufKind::ChunkedEnd(ref b) => b.chunk(), + BufKind::Trailers(ref b) => b.chunk(), } } @@ -212,6 +322,7 @@ where BufKind::Limited(ref mut b) => b.advance(cnt), BufKind::Chunked(ref mut b) => b.advance(cnt), BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + BufKind::Trailers(ref mut b) => b.advance(cnt), } } @@ -222,6 +333,7 @@ where BufKind::Limited(ref b) => b.chunks_vectored(dst), BufKind::Chunked(ref b) => b.chunks_vectored(dst), BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), + BufKind::Trailers(ref b) => b.chunks_vectored(dst), } } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index c30a4948f9..e9a38d569f 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -629,6 +629,7 @@ impl Server { }; let mut encoder = Encoder::length(0); + let mut allowed_trailer_fields: Option> = None; let mut wrote_date = false; let mut cur_name = None; let mut is_name_written = false; @@ -815,6 +816,38 @@ impl Server { header::DATE => { wrote_date = true; } + header::TRAILER => { + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + + if !is_name_written { + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "trailer: ", + header::TRAILER, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + + match allowed_trailer_fields { + Some(ref mut allowed_trailer_fields) => { + allowed_trailer_fields.push(value); + } + None => { + allowed_trailer_fields = Some(vec![value]); + } + } + + continue 'headers; + } _ => (), } //TODO: this should perhaps instead combine them into @@ -899,6 +932,12 @@ impl Server { extend(dst, b"\r\n"); } + if encoder.is_chunked() { + if let Some(allowed_trailer_fields) = allowed_trailer_fields { + encoder = encoder.into_chunked_with_trailing_fields(allowed_trailer_fields); + } + } + Ok(encoder.set_last(is_last)) } } @@ -1306,6 +1345,29 @@ impl Client { } }; + let encoder = encoder.map(|enc| { + if enc.is_chunked() { + let mut allowed_trailer_fields: Option> = None; + let trailers = headers.get_all(header::TRAILER); + for trailer in trailers.iter() { + match allowed_trailer_fields { + Some(ref mut allowed_trailer_fields) => { + allowed_trailer_fields.push(trailer.clone()); + } + None => { + allowed_trailer_fields = Some(vec![trailer.clone()]); + } + } + } + + if let Some(allowed_trailer_fields) = allowed_trailer_fields { + return enc.into_chunked_with_trailing_fields(allowed_trailer_fields); + } + } + + enc + }); + // This is because we need a second mutable borrow to remove // content-length header. if let Some(encoder) = encoder { @@ -1468,8 +1530,7 @@ fn title_case(dst: &mut Vec, name: &[u8]) { } } -#[cfg(feature = "client")] -fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { +pub(crate) fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { for (name, value) in headers { title_case(dst, name.as_str().as_bytes()); extend(dst, b": "); @@ -1478,8 +1539,7 @@ fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { } } -#[cfg(feature = "client")] -fn write_headers(headers: &HeaderMap, dst: &mut Vec) { +pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec) { for (name, value) in headers { extend(dst, name.as_str().as_bytes()); extend(dst, b": "); diff --git a/tests/client.rs b/tests/client.rs index b306016eea..8ac4a5e9b2 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -5,6 +5,7 @@ use std::convert::Infallible; use std::fmt; use std::future::Future; use std::io::{Read, Write}; +use std::iter::FromIterator; use std::net::{SocketAddr, TcpListener}; use std::pin::Pin; use std::thread; @@ -13,7 +14,7 @@ use std::time::Duration; use http::uri::PathAndQuery; use http_body_util::{BodyExt, StreamBody}; use hyper::body::Frame; -use hyper::header::HeaderValue; +use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use hyper::{Method, Request, StatusCode, Uri, Version}; use bytes::Bytes; @@ -408,6 +409,15 @@ macro_rules! __client_req_prop { Frame::data, ))); }}; + + ($req_builder:ident, $body:ident, $addr:ident, body_stream_with_trailers: $body_e:expr) => {{ + use support::trailers::StreamBodyWithTrailers; + let (body, trailers) = $body_e; + $body = BodyExt::boxed(StreamBodyWithTrailers::with_trailers( + futures_util::TryStreamExt::map_ok(body, Frame::data), + trailers, + )); + }}; } macro_rules! __client_req_header { @@ -631,6 +641,44 @@ test! { body: &b"hello"[..], } +test! { + name: client_post_req_body_chunked_with_trailer, + + server: + expected: "\ + POST / HTTP/1.1\r\n\ + trailer: chunky-trailer\r\n\ + host: {addr}\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0\r\n\ + chunky-trailer: header data\r\n\ + \r\n\ + ", + reply: REPLY_OK, + + client: + request: { + method: POST, + url: "http://{addr}/", + headers: { + "trailer" => "chunky-trailer", + }, + body_stream_with_trailers: ( + (futures_util::stream::once(async { Ok::<_, Infallible>(Bytes::from("hello"))})), + HeaderMap::from_iter(vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data") + )].into_iter())), + }, + response: + status: OK, + headers: {}, + body: None, +} + test! { name: client_get_req_body_sized, diff --git a/tests/server.rs b/tests/server.rs index 16a5a9afbe..edf569a860 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -19,7 +19,7 @@ use futures_channel::oneshot; use futures_util::future::{self, Either, FutureExt}; use h2::client::SendRequest; use h2::{RecvStream, SendStream}; -use http::header::{HeaderName, HeaderValue}; +use http::header::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; use hyper::rt::{Read as AsyncRead, Write as AsyncWrite}; @@ -2595,6 +2595,94 @@ async fn http2_keep_alive_count_server_pings() { .expect("timed out waiting for pings"); } +#[test] +fn http1_trailer_fields() { + let body = futures_util::stream::once(async move { Ok("hello".into()) }); + let mut headers = HeaderMap::new(); + headers.insert("chunky-trailer", "header data".parse().unwrap()); + // Invalid trailer field that should not be sent + headers.insert("Host", "www.example.com".parse().unwrap()); + // Not specified in Trailer header, so should not be sent + headers.insert("foo", "bar".parse().unwrap()); + + let server = serve(); + server + .reply() + .header("transfer-encoding", "chunked") + .header("trailer", "chunky-trailer") + .body_stream_with_trailers(body, headers); + let mut req = connect(server.addr()); + req.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + TE: trailers\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + let chunky_trailer_chunk = b"\r\nchunky-trailer: header data\r\n\r\n"; + let res = read_until(&mut req, |buf| buf.ends_with(chunky_trailer_chunk)).expect("reading"); + let sres = s(&res); + + let expected_head = + "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n"; + assert_eq!(&sres[..expected_head.len()], expected_head); + + // skip the date header + let date_fragment = "GMT\r\n\r\n"; + let pos = sres.find(date_fragment).expect("find GMT"); + let body = &sres[pos + date_fragment.len()..]; + + let expected_body = "5\r\nhello\r\n0\r\nchunky-trailer: header data\r\n\r\n"; + assert_eq!(body, expected_body); +} + +#[test] +fn http1_trailer_fields_not_allowed() { + let body = futures_util::stream::once(async move { Ok("hello".into()) }); + let mut headers = HeaderMap::new(); + headers.insert("chunky-trailer", "header data".parse().unwrap()); + + let server = serve(); + server + .reply() + .header("transfer-encoding", "chunked") + .header("trailer", "chunky-trailer") + .body_stream_with_trailers(body, headers); + let mut req = connect(server.addr()); + + // TE: trailers is not specified in request headers + req.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + let last_chunk = b"\r\n0\r\n\r\n"; + let res = read_until(&mut req, |buf| buf.ends_with(last_chunk)).expect("reading"); + let sres = s(&res); + + let expected_head = + "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n"; + assert_eq!(&sres[..expected_head.len()], expected_head); + + // skip the date header + let date_fragment = "GMT\r\n\r\n"; + let pos = sres.find(date_fragment).expect("find GMT"); + let body = &sres[pos + date_fragment.len()..]; + + // no trailer fields should be sent because TE: trailers was not in request headers + let expected_body = "5\r\nhello\r\n0\r\n\r\n"; + assert_eq!(body, expected_body); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // ------------------------------------------------- @@ -2700,6 +2788,19 @@ impl<'a> ReplyBuilder<'a> { self.tx.lock().unwrap().send(Reply::Body(body)).unwrap(); } + fn body_stream_with_trailers(self, stream: S, trailers: HeaderMap) + where + S: futures_util::Stream> + Send + Sync + 'static, + { + use futures_util::TryStreamExt; + use hyper::body::Frame; + use support::trailers::StreamBodyWithTrailers; + let mut stream_body = StreamBodyWithTrailers::new(stream.map_ok(Frame::data)); + stream_body.set_trailers(trailers); + let body = BodyExt::boxed(stream_body); + self.tx.lock().unwrap().send(Reply::Body(body)).unwrap(); + } + #[allow(dead_code)] fn error>(self, err: E) { self.tx diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796459412..1de834532d 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -24,6 +24,8 @@ mod tokiort; #[allow(unused)] pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; +pub mod trailers; + #[allow(unused_macros)] macro_rules! t { ( diff --git a/tests/support/trailers.rs b/tests/support/trailers.rs new file mode 100644 index 0000000000..a23664e31c --- /dev/null +++ b/tests/support/trailers.rs @@ -0,0 +1,76 @@ +use bytes::Buf; +use futures_util::stream::Stream; +use http::header::HeaderMap; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pin_project! { + /// A body created from a [`Stream`]. + #[derive(Clone, Debug)] + pub struct StreamBodyWithTrailers { + #[pin] + stream: S, + trailers: Option, + } +} + +impl StreamBodyWithTrailers { + /// Create a new `StreamBodyWithTrailers`. + pub fn new(stream: S) -> Self { + Self { + stream, + trailers: None, + } + } + + pub fn with_trailers(stream: S, trailers: HeaderMap) -> Self { + Self { + stream, + trailers: Some(trailers), + } + } + + pub fn set_trailers(&mut self, trailers: HeaderMap) { + self.trailers = Some(trailers); + } +} + +impl Body for StreamBodyWithTrailers +where + S: Stream, E>>, + D: Buf, +{ + type Data = D; + type Error = E; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let project = self.project(); + match project.stream.poll_next(cx) { + Poll::Ready(Some(result)) => Poll::Ready(Some(result)), + Poll::Ready(None) => match project.trailers.take() { + Some(trailers) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))), + None => Poll::Ready(None), + }, + Poll::Pending => Poll::Pending, + } + } +} + +impl Stream for StreamBodyWithTrailers { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} From 401dfaf0bd192c3772f4a66622a1e47c224c05b5 Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Fri, 10 Nov 2023 22:29:56 -0500 Subject: [PATCH 2/4] fix(http1): code review fixes - use more idiomatic expressions - add TE as invalid header - add tests for encode_trailers - fix bug in encode_trailers when buffer is empty --- src/proto/h1/conn.rs | 16 +++-- src/proto/h1/encode.rs | 159 +++++++++++++++++++++++++++++++++++------ src/proto/h1/role.rs | 16 +---- 3 files changed, 148 insertions(+), 43 deletions(-) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index c2719ac555..8007ebe5cb 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use std::time::Duration; use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; -use http::header::{HeaderValue, CONNECTION}; +use http::header::{HeaderValue, CONNECTION, TE}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; @@ -265,12 +265,14 @@ where self.state.reading = Reading::Body(Decoder::new(msg.decode)); } - if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) { - if te_value.eq_ignore_ascii_case("trailers") { - self.state.allow_trailer_fields = true; - } else { - self.state.allow_trailer_fields = false; - } + if msg + .head + .headers + .get(TE) + .map(|te_header| te_header == "trailers") + .unwrap_or(false) + { + self.state.allow_trailer_fields = true; } else { self.state.allow_trailer_fields = false; } diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 3a42a795a2..cbea0c0b09 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -7,7 +7,7 @@ use bytes::{Buf, Bytes}; use http::{ header::{ AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, - CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, }, HeaderMap, HeaderName, HeaderValue, }; @@ -169,13 +169,8 @@ impl Encoder { title_case_headers: bool, ) -> Option> { match &self.kind { - Kind::Chunked(allowed_trailer_fields) => { - let allowed_trailer_fields_map = match allowed_trailer_fields { - Some(ref allowed_trailer_fields) => { - allowed_trailer_field_map(&allowed_trailer_fields) - } - None => return None, - }; + Kind::Chunked(Some(ref allowed_trailer_fields)) => { + let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields); let mut cur_name = None; let mut allowed_trailers = HeaderMap::new(); @@ -186,8 +181,8 @@ impl Encoder { } let name = cur_name.as_ref().expect("current header name"); - if allowed_trailer_fields_map.contains_key(name.as_str()) - && !invalid_trailer_field(name) + if allowed_trailer_field_map.contains_key(name.as_str()) + && valid_trailer_field(name) { allowed_trailers.insert(name, value); } @@ -200,6 +195,10 @@ impl Encoder { write_headers(&allowed_trailers, &mut buf); } + if buf.is_empty() { + return None; + } + Some(EncodedBuf { kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")), }) @@ -256,20 +255,21 @@ impl Encoder { } } -fn invalid_trailer_field(name: &HeaderName) -> bool { +fn valid_trailer_field(name: &HeaderName) -> bool { match name { - &AUTHORIZATION => true, - &CACHE_CONTROL => true, - &CONTENT_ENCODING => true, - &CONTENT_LENGTH => true, - &CONTENT_RANGE => true, - &CONTENT_TYPE => true, - &HOST => true, - &MAX_FORWARDS => true, - &SET_COOKIE => true, - &TRAILER => true, - &TRANSFER_ENCODING => true, - _ => false, + &AUTHORIZATION => false, + &CACHE_CONTROL => false, + &CONTENT_ENCODING => false, + &CONTENT_LENGTH => false, + &CONTENT_RANGE => false, + &CONTENT_TYPE => false, + &HOST => false, + &MAX_FORWARDS => false, + &SET_COOKIE => false, + &TRAILER => false, + &TRANSFER_ENCODING => false, + &TE => false, + _ => true, } } @@ -439,7 +439,16 @@ impl std::error::Error for NotEof {} #[cfg(test)] mod tests { + use std::iter::FromIterator; + use bytes::BufMut; + use http::{ + header::{ + AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, + }, + HeaderMap, HeaderName, HeaderValue, + }; use super::super::io::Cursor; use super::Encoder; @@ -514,4 +523,108 @@ mod tests { assert!(!encoder.is_eof()); encoder.end::<()>().unwrap(); } + + #[test] + fn chunked_with_valid_trailers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ); + headers.insert( + HeaderName::from_static("should-not-be-included"), + HeaderValue::from_static("oops"), + ); + + let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n"); + } + + #[test] + fn chunked_with_no_trailer_header() { + let encoder = Encoder::chunked(); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + + assert!(encoder + .encode_trailers::<&[u8]>(headers.clone(), false) + .is_none()); + + let trailers = vec![]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none()); + } + + #[test] + fn chunked_with_invalid_trailers() { + let encoder = Encoder::chunked(); + + let trailers = format!( + "{},{},{},{},{},{},{},{},{},{},{},{}", + AUTHORIZATION, + CACHE_CONTROL, + CONTENT_ENCODING, + CONTENT_LENGTH, + CONTENT_RANGE, + CONTENT_TYPE, + HOST, + MAX_FORWARDS, + SET_COOKIE, + TRAILER, + TRANSFER_ENCODING, + TE, + ); + let trailers = vec![HeaderValue::from_str(&trailers).unwrap()]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("header data")); + headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data")); + headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data")); + headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data")); + headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data")); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data")); + headers.insert(HOST, HeaderValue::from_static("header data")); + headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data")); + headers.insert(SET_COOKIE, HeaderValue::from_static("header data")); + headers.insert(TRAILER, HeaderValue::from_static("header data")); + headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data")); + headers.insert(TE, HeaderValue::from_static("header data")); + + assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none()); + } + + #[test] + fn chunked_with_title_case_headers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n"); + } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index e9a38d569f..a5a08d8cab 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1347,20 +1347,10 @@ impl Client { let encoder = encoder.map(|enc| { if enc.is_chunked() { - let mut allowed_trailer_fields: Option> = None; - let trailers = headers.get_all(header::TRAILER); - for trailer in trailers.iter() { - match allowed_trailer_fields { - Some(ref mut allowed_trailer_fields) => { - allowed_trailer_fields.push(trailer.clone()); - } - None => { - allowed_trailer_fields = Some(vec![trailer.clone()]); - } - } - } + let allowed_trailer_fields: Vec = + headers.get_all(header::TRAILER).iter().cloned().collect(); - if let Some(allowed_trailer_fields) = allowed_trailer_fields { + if allowed_trailer_fields.len() > 0 { return enc.into_chunked_with_trailing_fields(allowed_trailer_fields); } } From 9cc26a59ae18ad81ff32eaf8fc1919695d7bfee4 Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Sat, 11 Nov 2023 07:10:48 -0500 Subject: [PATCH 3/4] fix(http1): code review fixes - remove uncessary if statement - prefer into_iter() instead of drain() - prefer !vec.is_empty() to vec.len() > 0 - add another test for multiple trailer headers in single response --- src/proto/h1/conn.rs | 9 ++----- src/proto/h1/encode.rs | 57 ++++++++++++++++++++++++++++++++++-------- src/proto/h1/role.rs | 2 +- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 8007ebe5cb..918ac62cdc 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -265,17 +265,12 @@ where self.state.reading = Reading::Body(Decoder::new(msg.decode)); } - if msg + self.state.allow_trailer_fields = msg .head .headers .get(TE) .map(|te_header| te_header == "trailers") - .unwrap_or(false) - { - self.state.allow_trailer_fields = true; - } else { - self.state.allow_trailer_fields = false; - } + .unwrap_or(false); Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) } diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index cbea0c0b09..41538aa90d 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -165,7 +165,7 @@ impl Encoder { pub(crate) fn encode_trailers( &self, - mut trailers: HeaderMap, + trailers: HeaderMap, title_case_headers: bool, ) -> Option> { match &self.kind { @@ -175,7 +175,7 @@ impl Encoder { let mut cur_name = None; let mut allowed_trailers = HeaderMap::new(); - for (opt_name, value) in trailers.drain() { + for (opt_name, value) in trailers.into_iter() { if let Some(n) = opt_name { cur_name = Some(n); } @@ -530,14 +530,18 @@ mod tests { let trailers = vec![HeaderValue::from_static("chunky-trailer")]; let encoder = encoder.into_chunked_with_trailing_fields(trailers); - let mut headers = HeaderMap::new(); - headers.insert( - HeaderName::from_static("chunky-trailer"), - HeaderValue::from_static("header data"), - ); - headers.insert( - HeaderName::from_static("should-not-be-included"), - HeaderValue::from_static("oops"), + let headers = HeaderMap::from_iter( + vec![ + ( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ), + ( + HeaderName::from_static("should-not-be-included"), + HeaderValue::from_static("oops"), + ), + ] + .into_iter(), ); let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); @@ -547,6 +551,39 @@ mod tests { assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n"); } + #[test] + fn chunked_with_multiple_trailer_headers() { + let encoder = Encoder::chunked(); + let trailers = vec![ + HeaderValue::from_static("chunky-trailer"), + HeaderValue::from_static("chunky-trailer-2"), + ]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![ + ( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ), + ( + HeaderName::from_static("chunky-trailer-2"), + HeaderValue::from_static("more header data"), + ), + ] + .into_iter(), + ); + + let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!( + dst, + b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n" + ); + } + #[test] fn chunked_with_no_trailer_header() { let encoder = Encoder::chunked(); diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index a5a08d8cab..6828db75a7 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1350,7 +1350,7 @@ impl Client { let allowed_trailer_fields: Vec = headers.get_all(header::TRAILER).iter().cloned().collect(); - if allowed_trailer_fields.len() > 0 { + if !allowed_trailer_fields.is_empty() { return enc.into_chunked_with_trailing_fields(allowed_trailer_fields); } } From 7a8eb3381387aed4f126a7c5fc366b8310b05525 Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Sat, 11 Nov 2023 09:47:53 -0500 Subject: [PATCH 4/4] fix(http1): code review fixes - remove unnecessary into_iter --- src/proto/h1/encode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 41538aa90d..90eeae4712 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -175,7 +175,7 @@ impl Encoder { let mut cur_name = None; let mut allowed_trailers = HeaderMap::new(); - for (opt_name, value) in trailers.into_iter() { + for (opt_name, value) in trailers { if let Some(n) = opt_name { cur_name = Some(n); }