Skip to content

Commit

Permalink
Client fixes, streamline tokio deps
Browse files Browse the repository at this point in the history
  • Loading branch information
tclem committed Jul 19, 2023
1 parent cf6855b commit 1201e2b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 105 deletions.
73 changes: 4 additions & 69 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 1 addition & 5 deletions crates/twirp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,4 @@ hyper = { version = "0.14", features = ["full"], optional = true }

# For the test-support feature
async-trait = { version = "0.1", optional = true }
tokio = { version = "1.21.2", features = ["full"], optional = true }

[dev-dependencies]
tokio = { version = "1.28", features = ["full", "test-util"] }
async-trait = "0.1"
tokio = { version = "1.28", features = [], optional = true }
59 changes: 30 additions & 29 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ impl ClientBuilder {
}
}

/// Add middleware to the client that will be called on each request.
/// Middlewares are invoked in the order they are added as part of the
/// request cycle.
pub fn with<M>(self, middleware: M) -> Self
where
M: Middleware,
Expand Down Expand Up @@ -102,51 +105,49 @@ impl std::fmt::Debug for Client {
}

impl Client {
/// Creates a `twirp::Client` with the default `reqwest::ClientBuilder`.
/// Creates a `twirp::Client`.
///
/// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
/// you create one and **reuse** it.
pub fn default(base_url: Url) -> Result<Self> {
pub fn new(
base_url: Url,
b: reqwest::ClientBuilder,
middlewares: Vec<Arc<dyn Middleware>>,
) -> Result<Self> {
if base_url.path().ends_with('/') {
Self::new(base_url, reqwest::ClientBuilder::default(), vec![])
let mut headers: HeaderMap<HeaderValue> = HeaderMap::default();
headers.insert(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF.try_into()?);
let client = b.default_headers(headers).build()?;
Ok(Client {
base_url: Arc::new(base_url),
client: Arc::new(client),
middlewares,
})
} else {
Err(ClientError::InvalidBaseUrl(base_url))
}
}

/// Creates a `twirp::Client`.
/// Creates a `twirp::Client` with the default `reqwest::ClientBuilder`.
///
/// The underlying `reqwest::Client` holds a connection pool internally, so it is advised that
/// you create one and **reuse** it.
pub fn new(
base_url: Url,
b: reqwest::ClientBuilder,
middlewares: Vec<Arc<dyn Middleware>>,
) -> Result<Self> {
let mut headers: HeaderMap<HeaderValue> = HeaderMap::default();
headers.insert(CONTENT_TYPE, CONTENT_TYPE_PROTOBUF.try_into()?);
let client = b.default_headers(headers).build()?;
Ok(Client {
base_url: Arc::new(base_url),
client: Arc::new(client),
middlewares,
})
pub fn from_base_url(base_url: Url) -> Result<Self> {
Self::new(base_url, reqwest::ClientBuilder::default(), vec![])
}

/// Add some middleware to the request stack.
pub fn with<M>(&self, middleware: M) -> Self
/// Add middleware to this specific request stack. Middlewares are invoked
/// in the order they are added as part of the request cycle. Middleware
/// added here will run after any middleware added with the `ClientBuilder`.
pub fn with<M>(mut self, middleware: M) -> Self
where
M: Middleware,
{
let mut middlewares = self.middlewares.clone();
middlewares.push(Arc::new(middleware));
Self {
base_url: self.base_url.clone(),
client: self.client.clone(),
middlewares,
}
self.middlewares.push(Arc::new(middleware));
self
}

/// Make an HTTP twirp request.
pub async fn request<I, O>(&self, url: Url, body: I) -> Result<O>
where
I: prost::Message,
Expand Down Expand Up @@ -258,10 +259,10 @@ mod tests {
#[tokio::test]
async fn test_base_url() {
let url = Url::parse("http://localhost:3001/twirp/").unwrap();
assert!(Client::default(url).is_ok());
assert!(Client::from_base_url(url).is_ok());
let url = Url::parse("http://localhost:3001/twirp").unwrap();
assert_eq!(
Client::default(url).unwrap_err().to_string(),
Client::from_base_url(url).unwrap_err().to_string(),
"base_url must end in /, but got: http://localhost:3001/twirp",
);
}
Expand All @@ -288,7 +289,7 @@ mod tests {
async fn test_standard_client() {
let h = run_test_server(3001).await;
let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
let client = Client::default(base_url).unwrap();
let client = Client::from_base_url(base_url).unwrap();
let resp = client
.ping(PingRequest {
name: "hi".to_string(),
Expand Down
2 changes: 1 addition & 1 deletion example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ prost = "0.11"
prost-wkt = "0.3"
prost-wkt-types = "0.3"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.21.2", features = ["full"] }
tokio = { version = "1.28", features = ["rt-multi-thread", "macros"] }

[build-dependencies]
fs-err = "2.8"
Expand Down
2 changes: 1 addition & 1 deletion example/src/bin/example-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use service::haberdash::v1::{HaberdasherAPIClient, MakeHatRequest, MakeHatRespon
pub async fn main() -> Result<(), GenericError> {
// basic client
use service::haberdash::v1::HaberdasherAPIClient;
let client = Client::default(Url::parse("http://localhost:3000/twirp/")?)?;
let client = Client::from_base_url(Url::parse("http://localhost:3000/twirp/")?)?;
let resp = client.make_hat(MakeHatRequest { inches: 1 }).await;
eprintln!("{:?}", resp);

Expand Down

0 comments on commit 1201e2b

Please sign in to comment.