diff --git a/Cargo.lock b/Cargo.lock index c639c01..5944b37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,7 +204,8 @@ dependencies = [ "async-trait", "fs-err", "glob", - "hyper", + "hyper 1.1.0", + "hyper-util", "prost", "prost-build", "prost-wkt", @@ -390,7 +391,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.11", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", "indexmap", "slab", "tokio", @@ -436,6 +456,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -443,7 +474,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.11", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", "pin-project-lite", ] @@ -469,9 +523,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.23", + "http 0.2.11", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -483,6 +537,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.2", + "http 1.0.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "want", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -490,12 +564,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.28", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.1.0", + "pin-project-lite", + "socket2", + "tokio", + "tracing", +] + [[package]] name = "idna" version = "0.5.0" @@ -922,10 +1014,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.23", + "http 0.2.11", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-tls", "ipnet", "js-sys", @@ -1237,8 +1329,12 @@ name = "twirp" version = "0.1.0" dependencies = [ "async-trait", + "bytes", "futures", - "hyper", + "http-body-util", + "hyper 1.1.0", + "hyper-util", + "pin-project-lite", "prost", "reqwest", "serde", diff --git a/crates/twirp/Cargo.toml b/crates/twirp/Cargo.toml index b6a4676..0b8ab6b 100644 --- a/crates/twirp/Cargo.toml +++ b/crates/twirp/Cargo.toml @@ -22,8 +22,15 @@ reqwest = { version = "0.11", features = ["default", "gzip", "json"], optional = url = { version = "2.5", optional = true } # For the server feature -hyper = { version = "0.14", features = ["full"], optional = true } +bytes = "1.5" +http-body-util = "0.1" +hyper = { version = "1.1", features = ["full"], optional = true } +hyper-util = { version = "0.1", features = ["tokio"] } +pin-project-lite = "0.2" # For the test-support feature async-trait = { version = "0.1", optional = true } tokio = { version = "1.33", features = [], optional = true } + +[dev-dependencies] +tokio = { version = "1.33", features = ["rt", "macros"] } diff --git a/crates/twirp/src/body.rs b/crates/twirp/src/body.rs new file mode 100644 index 0000000..112df31 --- /dev/null +++ b/crates/twirp/src/body.rs @@ -0,0 +1,112 @@ +use std::fmt::{self, Debug, Formatter}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use http_body_util::combinators::UnsyncBoxBody; +use http_body_util::BodyExt; +use hyper::body::Frame; + +use crate::GenericError; + +type BoxBody = UnsyncBoxBody; + +pin_project_lite::pin_project! { + /// Generic body type (like `axum::body::Body`). + pub struct Body { + #[pin] + inner: BoxBody + } +} + +impl Debug for Body { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Body").finish() + } +} + +impl From for Body { + fn from(bytes: Bytes) -> Self { + Body { + inner: BoxBody::new(http_body_util::Full::new(bytes).map_err(|err| match err {})), + } + } +} + +impl From> for Body { + fn from(bytes: Vec) -> Self { + Bytes::from(bytes).into() + } +} + +impl From for Body { + fn from(text: String) -> Self { + Bytes::from(text).into() + } +} + +impl From<&'static str> for Body { + fn from(text: &'static str) -> Self { + Bytes::from(text).into() + } +} + +impl Body { + /// Create a new Body that wraps another `http::body::Body`. + pub fn new(body: B) -> Self + where + B: hyper::body::Body + Send + 'static, + B::Error: Into, + { + Body { + inner: BoxBody::new(body.map_err(|err| err.into())), + } + } + + pub fn empty() -> Self { + Self::new(http_body_util::Empty::new()) + } + + pub(crate) fn protobuf(message: &T) -> Self + where + T: prost::Message, + { + serialize_proto_message(message).into() + } + + pub(crate) fn json(data: &T) -> Result + where + T: serde::Serialize, + { + let json = serde_json::to_string(&data)?; + Ok(json.into()) + } +} + +pub(crate) fn serialize_proto_message(m: &T) -> Vec +where + T: prost::Message, +{ + let len = m.encoded_len(); + let mut data = Vec::with_capacity(len); + m.encode(&mut data) + .expect("can only fail if buffer does not have capacity"); + assert_eq!(data.len(), len); + data +} + +impl hyper::body::Body for Body { + /// Values yielded by the `Body`. + type Data = bytes::Bytes; + + /// The error type this `Body` might generate. + type Error = GenericError; + + /// Attempt to pull out the next data buffer of this stream. + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) + } +} diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 5ff23b7..79b64a9 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -1,13 +1,13 @@ use std::sync::Arc; use async_trait::async_trait; -use hyper::header::{InvalidHeaderValue, CONTENT_TYPE}; -use hyper::StatusCode; +use reqwest::header::{InvalidHeaderValue, CONTENT_TYPE}; +use reqwest::StatusCode; use thiserror::Error; use url::Url; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{to_proto_body, TwirpErrorResponse}; +use crate::{body, TwirpErrorResponse}; #[derive(Debug, Error)] pub enum ClientError { @@ -146,7 +146,7 @@ impl Client { .http_client .post(url) .header(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(to_proto_body(body)) + .body(body::serialize_proto_message(&body)) .build()?; // Create and execute the middleware handlers diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index e8ab926..c15f3e4 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -2,11 +2,13 @@ use std::collections::HashMap; -use hyper::{header, Body, Response, StatusCode}; +use hyper::{header, Response, StatusCode}; use serde::{Deserialize, Serialize, Serializer}; +use crate::Body; + // Alias for a generic error -pub type GenericError = Box; +pub type GenericError = Box; macro_rules! twirp_error_codes { ( diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index 49c08d8..fda18f0 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -1,6 +1,7 @@ #[cfg(feature = "client")] pub mod client; +mod body; pub mod error; pub mod headers; pub mod server; @@ -8,6 +9,7 @@ pub mod server; #[cfg(any(test, feature = "test-support"))] pub mod test; +pub use body::Body; pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; pub use error::*; // many constructors like `invalid_argument()` pub use server::{serve, Router, Timings}; @@ -17,15 +19,3 @@ pub use reqwest; // Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate. pub use url; - -pub(crate) fn to_proto_body(m: T) -> hyper::Body -where - T: prost::Message, -{ - let len = m.encoded_len(); - let mut data = Vec::with_capacity(len); - m.encode(&mut data) - .expect("can only fail if buffer does not have capacity"); - assert_eq!(data.len(), len); - hyper::Body::from(data) -} diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index d8fb863..140a8cd 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -3,13 +3,14 @@ use std::fmt::Debug; use std::sync::Arc; use futures::Future; -use hyper::{header, Body, Method, Request, Response}; +use http_body_util::BodyExt; +use hyper::{header, Method, Request, Response}; use serde::de::DeserializeOwned; use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, to_proto_body, GenericError, TwirpErrorResponse}; +use crate::{error, Body, GenericError, TwirpErrorResponse}; /// A function that handles a request and returns a response. type HandlerFn = Box) -> HandlerResponse + Send + Sync>; @@ -180,7 +181,7 @@ where T: prost::Message + Default + DeserializeOwned, { let format = BodyFormat::from_content_type(&req); - let bytes = hyper::body::to_bytes(req.into_body()).await?; + let bytes = req.into_body().collect().await?.to_bytes(); timings.set_received(); let request = match format { BodyFormat::Pb => T::decode(bytes)?, @@ -202,14 +203,13 @@ where BodyFormat::Pb => { let response = Response::builder() .header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) - .body(to_proto_body(response))?; + .body(Body::protobuf(&response))?; Ok(response) } _ => { - let data = serde_json::to_string(&response)?; let response = Response::builder() .header(header::CONTENT_TYPE, CONTENT_TYPE_JSON) - .body(Body::from(data))?; + .body(Body::json(&response)?)?; Ok(response) } }, @@ -313,7 +313,7 @@ mod tests { #[tokio::test] async fn test_routes() { - let router = test_api_router().await; + let router = test_api_router(); assert!(router .routes .contains_key(&(Method::POST, "/twirp/test.TestAPI/Ping".to_string()))); @@ -324,7 +324,7 @@ mod tests { #[tokio::test] async fn test_ping_success() { - let router = test_api_router().await; + let router = test_api_router(); let resp = serve(router, gen_ping_request("hi")).await.unwrap(); assert!(resp.status().is_success(), "{:?}", resp); let data: PingResponse = read_json_body(resp.into_body()).await; @@ -333,7 +333,7 @@ mod tests { #[tokio::test] async fn test_ping_invalid_request() { - let router = test_api_router().await; + let router = test_api_router(); let req = Request::post("/twirp/test.TestAPI/Ping") .body(Body::empty()) // not a valid request .unwrap(); @@ -354,7 +354,7 @@ mod tests { #[tokio::test] async fn test_boom() { - let router = test_api_router().await; + let router = test_api_router(); let req = serde_json::to_string(&PingRequest { name: "hi".to_string(), }) diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index df3ac9e..a5d3809 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -1,31 +1,60 @@ //! Test helpers and mini twirp api server implementation. +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Server}; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::{service_fn, Service}; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; use serde::de::DeserializeOwned; +use tokio::net::TcpListener; use tokio::task::JoinHandle; -use crate::{error, Client, GenericError, Result, Router, TwirpErrorResponse}; +use crate::{error, Body, Client, GenericError, Result, Router, TwirpErrorResponse}; -pub async fn run_test_server(port: u16) -> JoinHandle> { - let router = test_api_router().await; - let service = make_service_fn(move |_| { - let router = router.clone(); - async { Ok::<_, GenericError>(service_fn(move |req| crate::serve(router.clone(), req))) } +async fn test_server_main(tcp_listener: TcpListener, service: S) -> Result<(), std::io::Error> +where + S: Clone + Service, Response = Response> + Send + 'static, + S::Future: Send + 'static, + S::Error: Into, +{ + loop { + let (stream, _) = tcp_listener.accept().await?; + let io = TokioIo::new(stream); + let service = service.clone(); + let task = async move { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + eprintln!("test server: error serving connection: {err:#}"); + } + }; + tokio::spawn(task); + } +} + +pub async fn run_test_server(port: u16) -> JoinHandle> { + let router = test_api_router(); + let service = service_fn(move |req: Request| { + let router = Arc::clone(&router); + async move { + let req = req.map(Body::new); + crate::serve(router, req).await + } }); - let addr = ([127, 0, 0, 1], port).into(); - let server = Server::bind(&addr).serve(service); + let addr: SocketAddr = ([127, 0, 0, 1], port).into(); + let tcp_listener = TcpListener::bind(&addr).await.unwrap(); + let server = test_server_main(tcp_listener, service); println!("Listening on {addr}"); let h = tokio::spawn(server); tokio::time::sleep(Duration::from_millis(100)).await; h } -pub async fn test_api_router() -> Arc { +pub fn test_api_router() -> Arc { let api = Arc::new(TestAPIServer {}); let mut router = Router::default(); // NB: This would be generated @@ -45,7 +74,7 @@ pub async fn test_api_router() -> Arc { Arc::new(router) } -pub fn gen_ping_request(name: &str) -> Request { +pub fn gen_ping_request(name: &str) -> Request { let req = serde_json::to_string(&PingRequest { name: name.to_string(), }) @@ -56,9 +85,11 @@ pub fn gen_ping_request(name: &str) -> Request { } pub async fn read_string_body(body: Body) -> String { - let data = hyper::body::to_bytes(body) + let data = body + .collect() .await .expect("invalid body") + .to_bytes() .to_vec(); String::from_utf8(data).expect("non-utf8 body") } @@ -67,10 +98,7 @@ pub async fn read_json_body(body: Body) -> T where T: DeserializeOwned, { - let data = hyper::body::to_bytes(body) - .await - .expect("invalid body") - .to_vec(); + let data = body.collect().await.expect("invalid body").to_bytes(); serde_json::from_slice(&data).expect("twirp response isn't valid JSON") } diff --git a/example/Cargo.toml b/example/Cargo.toml index b161acf..a20773e 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -6,7 +6,8 @@ edition = "2021" [dependencies] twirp = { path = "../crates/twirp" } async-trait = "0.1" -hyper = { version = "0.14", features = ["full"] } +hyper = { version = "1.1", features = ["full"] } +hyper-util = { version = "0.1", features = ["tokio"] } prost = "0.12" prost-wkt = "0.5" prost-wkt-types = "0.5" diff --git a/example/src/main.rs b/example/src/main.rs index 2aaf783..f1dc009 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -2,9 +2,13 @@ use std::sync::Arc; use std::time::UNIX_EPOCH; use async_trait::async_trait; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Method, Response, Server}; -use twirp::{invalid_argument, GenericError, Router, TwirpErrorResponse}; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::{service_fn, Service}; +use hyper::{Method, Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use twirp::{invalid_argument, Body, GenericError, Router, TwirpErrorResponse}; pub mod service { pub mod haberdash { @@ -25,19 +29,40 @@ pub async fn main() { }); println!("{router:?}"); let router = Arc::new(router); - let service = make_service_fn(move |_| { - let router = router.clone(); - async { Ok::<_, GenericError>(service_fn(move |req| twirp::serve(router.clone(), req))) } + let service = service_fn(move |req: Request| { + let router = Arc::clone(&router); + async move { + let req = req.map(Body::new); + twirp::serve(router, req).await + } }); - let addr = ([127, 0, 0, 1], 3000).into(); - let server = Server::bind(&addr).serve(service); - println!("Listening on {addr}"); - if let Err(e) = server.await { + let tcp_listener = TcpListener::bind("localhost:3000").await.unwrap(); + println!("Listening on localhost:3000"); + if let Err(e) = serve_forever(tcp_listener, service).await { eprintln!("server error: {}", e); } } +async fn serve_forever(tcp_listener: TcpListener, service: S) -> Result<(), std::io::Error> +where + S: Clone + Service, Response = Response> + Send + 'static, + S::Future: Send + 'static, + S::Error: Into, +{ + loop { + let (stream, _) = tcp_listener.accept().await?; + let io = TokioIo::new(stream); + let service = service.clone(); + let task = async move { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + eprintln!("test server: error serving connection: {err:#}"); + } + }; + tokio::spawn(task); + } +} + struct HaberdasherAPIServer; #[async_trait]