From cb7a6f34aa78d80a5b8892585d1a076b93d45dfc Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Fri, 3 Nov 2023 06:25:47 -0400 Subject: [PATCH] only send trailer fields if TE: trailers request header is present --- src/proto/h1/conn.rs | 14 +++++++++++++- src/proto/h1/mod.rs | 1 - tests/server.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index f53318236f..56fe9e94da 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -75,6 +75,7 @@ where // We assume a modern world where the remote speaks HTTP/1.1. // If they tell us otherwise, we'll downgrade in `read_head`. version: Version::HTTP_11, + allow_trailer_fields: false, }, _marker: PhantomData, } @@ -266,8 +267,12 @@ where if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) { if te_value.eq_ignore_ascii_case("trailers") { - wants = wants.add(Wants::TRAILERS); + self.state.allow_trailer_fields = true; + } else { + self.state.allow_trailer_fields = false; } + } else { + self.state.allow_trailer_fields = false; } Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) @@ -647,6 +652,11 @@ where } pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) { + if self.state.allow_trailer_fields == false { + debug!("trailers not allowed to be sent"); + return; + } + debug_assert!(self.can_write_body() && self.can_buffer_body()); match self.state.writing { @@ -869,6 +879,8 @@ struct State { upgrade: Option, /// Either HTTP/1.0 or 1.1 connection version: Version, + /// Flag to track if trailer fields are allowed to be sent + allow_trailer_fields: bool, } #[derive(Debug)] diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 210a3d2d75..86561c3764 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -112,7 +112,6 @@ impl Wants { const EMPTY: Wants = Wants(0b00); const EXPECT: Wants = Wants(0b01); const UPGRADE: Wants = Wants(0b10); - const TRAILERS: Wants = Wants(0b100); #[must_use] fn add(self, other: Wants) -> Wants { diff --git a/tests/server.rs b/tests/server.rs index 31800fcafa..edf569a860 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -2640,6 +2640,49 @@ fn http1_trailer_fields() { assert_eq!(body, expected_body); } +#[test] +fn http1_trailer_fields_not_allowed() { + let body = futures_util::stream::once(async move { Ok("hello".into()) }); + let mut headers = HeaderMap::new(); + headers.insert("chunky-trailer", "header data".parse().unwrap()); + + let server = serve(); + server + .reply() + .header("transfer-encoding", "chunked") + .header("trailer", "chunky-trailer") + .body_stream_with_trailers(body, headers); + let mut req = connect(server.addr()); + + // TE: trailers is not specified in request headers + req.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + let last_chunk = b"\r\n0\r\n\r\n"; + let res = read_until(&mut req, |buf| buf.ends_with(last_chunk)).expect("reading"); + let sres = s(&res); + + let expected_head = + "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n"; + assert_eq!(&sres[..expected_head.len()], expected_head); + + // skip the date header + let date_fragment = "GMT\r\n\r\n"; + let pos = sres.find(date_fragment).expect("find GMT"); + let body = &sres[pos + date_fragment.len()..]; + + // no trailer fields should be sent because TE: trailers was not in request headers + let expected_body = "5\r\nhello\r\n0\r\n\r\n"; + assert_eq!(body, expected_body); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // -------------------------------------------------