Skip to content

Commit

Permalink
feat: add layer that limits body size (#271)
Browse files Browse the repository at this point in the history
* feat: add layer that limits body size

* chore: rename `LengthLimited` to `RequestBodyLimit`

* refactor: remove request body witness from layer

* refactor: remove need to box wrapped service error

* fix: impl bounds to work as a service

* feat: handle `Content-Length` case, improve docs

* docs: add example and recommendation for without 413 handling

* refactor: split out body and data

* refactor: get rid of custom data type

* refactor: stop interpreting service error

This removes the `StdError + 'static` bound on the service error.
Also sets the read limit for a given message to the lesser of
`Content-Length` or the configured limit.

* refactor: remove accessory function for finding nested source

* docs: improve limit module documentation

* misc docs changes

* Derive `Debug`

* Order functions like we normally do

* Add `RequestBodyLimit::layer`

* fix nit picks

* changelog

* add note about hyper

Co-authored-by: David Pedersen <[email protected]>
  • Loading branch information
neoeinstein and davidpdrsn authored Jun 6, 2022
1 parent 960c83b commit 5468fc8
Show file tree
Hide file tree
Showing 9 changed files with 431 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added

- Add `Timeout` middleware ([#270])
- Add `RequestBodyLimit` middleware ([#271])

[#270]: https://github.com/tower-rs/tower-http/pull/270
[#271]: https://github.com/tower-rs/tower-http/pull/271

## Changed

Expand Down
4 changes: 3 additions & 1 deletion tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bytes = "1"
futures-core = "0.3"
futures-util = { version = "0.3.14", default_features = false, features = [] }
http = "0.2.2"
http-body = "0.4.1"
http-body = "0.4.5"
pin-project-lite = "0.2.7"
tower-layer = "0.3"
tower-service = "0.3"
Expand Down Expand Up @@ -62,6 +62,7 @@ full = [
"decompression-full",
"follow-redirect",
"fs",
"limit",
"map-request-body",
"map-response-body",
"metrics",
Expand All @@ -82,6 +83,7 @@ catch-panic = ["tracing", "futures-util/std"]
cors = []
follow-redirect = ["iri-string", "tower/util"]
fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"]
limit = []
map-request-body = []
map-response-body = []
metrics = ["tokio/time"]
Expand Down
20 changes: 20 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,18 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
) -> ServiceBuilder<
Stack<crate::catch_panic::CatchPanicLayer<crate::catch_panic::DefaultResponseForPanic>, L>,
>;

/// Intercept requests with over-sized payloads and convert them into
/// `413 Payload Too Large` responses.
///
/// See [`tower_http::limit`] for more details.
///
/// [`tower_http::limit`]: crate::limit
#[cfg(feature = "limit")]
fn request_body_limit(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>>;
}

impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -558,4 +570,12 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
> {
self.layer(crate::catch_panic::CatchPanicLayer::new())
}

#[cfg(feature = "limit")]
fn request_body_limit(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>> {
self.layer(crate::limit::RequestBodyLimitLayer::new(limit))
}
}
3 changes: 3 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ pub mod trace;
#[cfg(feature = "follow-redirect")]
pub mod follow_redirect;

#[cfg(feature = "limit")]
pub mod limit;

#[cfg(feature = "metrics")]
pub mod metrics;

Expand Down
107 changes: 107 additions & 0 deletions tower-http/src/limit/body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, Response, StatusCode};
use http_body::{Body, Full, SizeHint};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Response body for [`RequestBodyLimit`].
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
pub struct ResponseBody<B> {
#[pin]
inner: ResponseBodyInner<B>
}
}

impl<B> ResponseBody<B> {
fn payload_too_large() -> Self {
Self {
inner: ResponseBodyInner::PayloadTooLarge {
body: Full::from(BODY),
},
}
}

pub(crate) fn new(body: B) -> Self {
Self {
inner: ResponseBodyInner::Body { body },
}
}
}

pin_project! {
#[project = BodyProj]
enum ResponseBodyInner<B> {
PayloadTooLarge {
#[pin]
body: Full<Bytes>,
},
Body {
#[pin]
body: B
}
}
}

impl<B> Body for ResponseBody<B>
where
B: Body<Data = Bytes>,
{
type Data = Bytes;
type Error = B::Error;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.project().inner.project() {
BodyProj::PayloadTooLarge { body } => body.poll_data(cx).map_err(|err| match err {}),
BodyProj::Body { body } => body.poll_data(cx),
}
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
match self.project().inner.project() {
BodyProj::PayloadTooLarge { body } => {
body.poll_trailers(cx).map_err(|err| match err {})
}
BodyProj::Body { body } => body.poll_trailers(cx),
}
}

fn is_end_stream(&self) -> bool {
match &self.inner {
ResponseBodyInner::PayloadTooLarge { body } => body.is_end_stream(),
ResponseBodyInner::Body { body } => body.is_end_stream(),
}
}

fn size_hint(&self) -> SizeHint {
match &self.inner {
ResponseBodyInner::PayloadTooLarge { body } => body.size_hint(),
ResponseBodyInner::Body { body } => body.size_hint(),
}
}
}

const BODY: &[u8] = b"length limit exceeded";

pub(crate) fn create_error_response<B>() -> Response<ResponseBody<B>>
where
B: Body,
{
let mut res = Response::new(ResponseBody::payload_too_large());
*res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;

#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
res.headers_mut()
.insert(http::header::CONTENT_TYPE, TEXT_PLAIN);

res
}
61 changes: 61 additions & 0 deletions tower-http/src/limit/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use super::body::create_error_response;
use super::ResponseBody;
use futures_core::ready;
use http::Response;
use http_body::Body;
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Response future for [`RequestBodyLimit`].
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
pub struct ResponseFuture<F> {
#[pin]
inner: ResponseFutureInner<F>,
}
}

impl<F> ResponseFuture<F> {
pub(crate) fn payload_too_large() -> Self {
Self {
inner: ResponseFutureInner::PayloadTooLarge,
}
}

pub(crate) fn new(future: F) -> Self {
Self {
inner: ResponseFutureInner::Future { future },
}
}
}

pin_project! {
#[project = ResFutProj]
enum ResponseFutureInner<F> {
PayloadTooLarge,
Future {
#[pin]
future: F,
}
}
}

impl<ResBody, F, E> Future for ResponseFuture<F>
where
ResBody: Body,
F: Future<Output = Result<Response<ResBody>, E>>,
{
type Output = Result<Response<ResponseBody<ResBody>>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match self.project().inner.project() {
ResFutProj::PayloadTooLarge => create_error_response(),
ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new),
};

Poll::Ready(Ok(res))
}
}
32 changes: 32 additions & 0 deletions tower-http/src/limit/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use super::RequestBodyLimit;
use tower_layer::Layer;

/// Layer that applies the [`RequestBodyLimit`] middleware that intercepts requests
/// with body lengths greater than the configured limit and converts them into
/// `413 Payload Too Large` responses.
///
/// See the [module docs](crate::limit) for an example.
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
#[derive(Clone, Copy, Debug)]
pub struct RequestBodyLimitLayer {
limit: usize,
}

impl RequestBodyLimitLayer {
/// Create a new `RequestBodyLimitLayer` with the given body length limit.
pub fn new(limit: usize) -> Self {
Self { limit }
}
}

impl<S> Layer<S> for RequestBodyLimitLayer {
type Service = RequestBodyLimit<S>;

fn layer(&self, inner: S) -> Self::Service {
RequestBodyLimit {
inner,
limit: self.limit,
}
}
}
Loading

0 comments on commit 5468fc8

Please sign in to comment.