From 56f62a0cb63e822c9316213b0b8a21467d81b385 Mon Sep 17 00:00:00 2001 From: Timothy Clem Date: Fri, 3 Nov 2023 14:13:53 -0700 Subject: [PATCH] Minor refactoring --- crates/twirp/src/client.rs | 55 +++++++++++++++---------------- crates/twirp/src/server.rs | 16 +++++++++ example/src/bin/example-client.rs | 2 +- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index b826c3c..42fed26 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use async_trait::async_trait; -use hyper::header::{InvalidHeaderValue, CONTENT_TYPE}; -use hyper::http::HeaderValue; -use hyper::{HeaderMap, StatusCode}; +use hyper::header; +use hyper::StatusCode; use thiserror::Error; use url::Url; @@ -14,7 +13,7 @@ use crate::{error::*, to_proto_body}; #[derive(Debug, Error)] pub enum ClientError { #[error(transparent)] - InvalidHeader(#[from] InvalidHeaderValue), + InvalidHeader(#[from] header::InvalidHeaderValue), #[error("base_url must end in /, but got: {0}")] InvalidBaseUrl(Url), #[error(transparent)] @@ -45,19 +44,23 @@ pub type Result = core::result::Result; /// Use ClientBuilder to build a TwirpClient. pub struct ClientBuilder { base_url: Url, - builder: reqwest::ClientBuilder, + http_client: reqwest::Client, middleware: Vec>, } impl ClientBuilder { - pub fn new(base_url: Url) -> Self { + pub fn new(base_url: Url, http_client: reqwest::Client) -> Self { Self { base_url, - builder: reqwest::ClientBuilder::default(), + http_client, middleware: vec![], } } + pub fn from_base_url(base_url: Url) -> Self { + Self::new(base_url, reqwest::Client::default()) + } + /// Add middleware to the client that will be called on each request. /// Middlewares are invoked in the order they are added as part of the /// request cycle. @@ -69,21 +72,13 @@ impl ClientBuilder { mw.push(Arc::new(middleware)); Self { base_url: self.base_url, - builder: self.builder, + http_client: self.http_client, middleware: mw, } } - pub fn with_client_builder(self, builder: reqwest::ClientBuilder) -> Self { - Self { - base_url: self.base_url, - builder, - middleware: self.middleware, - } - } - pub fn build(self) -> Result { - Client::new(self.base_url, self.builder, self.middleware) + Client::new(self.base_url, self.http_client, self.middleware) } } @@ -92,7 +87,7 @@ impl ClientBuilder { #[derive(Clone)] pub struct Client { pub base_url: Arc, - client: Arc, + http_client: Arc, middlewares: Vec>, } @@ -100,7 +95,7 @@ impl std::fmt::Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TwirpClient") .field("base_url", &self.base_url) - .field("client", &self.client) + .field("client", &self.http_client) .field("middlewares", &self.middlewares.len()) .finish() } @@ -113,16 +108,13 @@ impl Client { /// you create one and **reuse** it. pub fn new( base_url: Url, - b: reqwest::ClientBuilder, + http_client: reqwest::Client, middlewares: Vec>, ) -> Result { if base_url.path().ends_with('/') { - let mut headers: HeaderMap = HeaderMap::default(); - headers.insert(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF.try_into()?); - let client = b.default_headers(headers).build()?; Ok(Client { base_url: Arc::new(base_url), - client: Arc::new(client), + http_client: Arc::new(http_client), middlewares, }) } else { @@ -135,7 +127,7 @@ impl Client { /// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that /// you create one and **reuse** it. pub fn from_base_url(base_url: Url) -> Result { - Self::new(base_url, reqwest::ClientBuilder::default(), vec![]) + Self::new(base_url, reqwest::Client::default(), vec![]) } /// Add middleware to this specific request stack. Middlewares are invoked @@ -156,15 +148,20 @@ impl Client { O: prost::Message + Default, { let path = url.path().to_string(); - let req = self.client.post(url).body(to_proto_body(body)).build()?; + let req = self + .http_client + .post(url) + .header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) + .body(to_proto_body(body)) + .build()?; // Create and execute the middleware handlers - let next = Next::new(&self.client, &self.middlewares); + let next = Next::new(&self.http_client, &self.middlewares); let resp = next.run(req).await?; // These have to be extracted because reading the body consumes `Response`. let status = resp.status(); - let content_type = resp.headers().get(CONTENT_TYPE).cloned(); + let content_type = resp.headers().get(header::CONTENT_TYPE).cloned(); match (status, content_type) { (status, Some(ct)) if status.is_success() && ct.as_bytes() == CONTENT_TYPE_PROTOBUF => { @@ -271,7 +268,7 @@ mod tests { #[tokio::test] async fn test_routes() { let base_url = Url::parse("http://localhost:3001/twirp/").unwrap(); - let client = ClientBuilder::new(base_url) + let client = ClientBuilder::new(base_url, reqwest::Client::new()) .with(AssertRouting { expected_url: "http://localhost:3001/twirp/test.TestAPI/Ping", }) diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 4f405cc..4938981 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -71,6 +71,22 @@ impl Router { self.routes.insert(key, Box::new(g)); } + /// Adds an async handler to the router for the given method and path. + pub fn add_async_handler(&mut self, method: Method, path: &str, f: F) + where + F: Fn(Request) -> Fut + Clone + Sync + Send + 'static, + Fut: Future, GenericError>> + Send, + { + let g = move |req| -> Box< + dyn Future, GenericError>> + Unpin + Send, + > { + let f = f.clone(); + Box::new(Box::pin(async move { f(req).await })) + }; + let key = (method, path.to_string()); + self.routes.insert(key, Box::new(g)); + } + /// Adds a twirp method handler to the router for the given path. pub fn add_method(&mut self, path: &str, f: F) where diff --git a/example/src/bin/example-client.rs b/example/src/bin/example-client.rs index a0ae847..5dba302 100644 --- a/example/src/bin/example-client.rs +++ b/example/src/bin/example-client.rs @@ -23,7 +23,7 @@ pub async fn main() -> Result<(), GenericError> { eprintln!("{:?}", resp); // customize the client with middleware - let client = ClientBuilder::new(Url::parse("http://xyz:3000/twirp/")?) + let client = ClientBuilder::from_base_url(Url::parse("http://xyz:3000/twirp/")?) .with(RequestHeaders { hmac_key: None }) .build()?; let resp = client