Skip to content

Commit

Permalink
Rewrite traits to use async-fn-in-trait.
Browse files Browse the repository at this point in the history
- Stub
- BeforeRequest
- AfterRequest

Also removed the last remaining usage of an unstable feature,
iter_intersperse.
  • Loading branch information
tikue committed Dec 29, 2023
1 parent 84932df commit 6cf18a1
Show file tree
Hide file tree
Showing 22 changed files with 97 additions and 251 deletions.
3 changes: 0 additions & 3 deletions example-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use std::env;
use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};

Expand Down
3 changes: 0 additions & 3 deletions example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use clap::Parser;
use futures::{future, prelude::*};
use rand::{
Expand Down
3 changes: 0 additions & 3 deletions plugins/tests/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

// these need to be out here rather than inside the function so that the
// assert_type_eq macro can pick them up.
#[tarpc::service]
Expand Down
3 changes: 0 additions & 3 deletions plugins/tests/service.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use tarpc::context;

#[test]
Expand Down
3 changes: 0 additions & 3 deletions tarpc/examples/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
Expand Down
3 changes: 0 additions & 3 deletions tarpc/examples/custom_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use futures::prelude::*;
use tarpc::context::Context;
use tarpc::serde_transport as transport;
Expand Down
3 changes: 0 additions & 3 deletions tarpc/examples/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

/// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher"
/// port. Because both publishers and subscribers initiate their connections to the PubSub
/// server, the server requires no prior knowledge of either publishers or subscribers.
Expand Down
3 changes: 0 additions & 3 deletions tarpc/examples/readme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use futures::prelude::*;
use tarpc::{
client, context,
Expand Down
3 changes: 0 additions & 3 deletions tarpc/examples/tls_over_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

use futures::prelude::*;
use rustls_pemfile::certs;
use std::io::{BufReader, Cursor};
Expand Down
4 changes: 0 additions & 4 deletions tarpc/examples/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(incomplete_features)]
#![feature(async_fn_in_trait, type_alias_impl_trait)]

use crate::{
add::{Add as AddService, AddStub},
double::Double as DoubleService,
Expand Down Expand Up @@ -69,7 +66,6 @@ struct DoubleServer<Stub> {
impl<Stub> DoubleService for DoubleServer<Stub>
where
Stub: AddStub + Clone + Send + Sync + 'static,
for<'a> Stub::RespFut<'a>: Send,
{
async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
self.add_client
Expand Down
27 changes: 8 additions & 19 deletions tarpc/src/client/stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
client::{Channel, RpcError},
context,
};
use futures::prelude::*;

pub mod load_balance;
pub mod retry;
Expand All @@ -14,43 +13,33 @@ mod mock;

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req;

/// The service response type.
type Resp;

/// The type of the future returned by `Stub::call`.
type RespFut<'a>: Future<Output = Result<Self::Resp, RpcError>>
where
Self: 'a,
Self::Req: 'a,
Self::Resp: 'a;

/// Calls a remote service.
fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a>;
) -> Result<Self::Resp, RpcError>;
}

impl<Req, Resp> Stub for Channel<Req, Resp> {
type Req = Req;
type Resp = Resp;
type RespFut<'a> = RespFut<'a, Req, Resp>
where Self: 'a;

fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request_name, request).await
}
}

type RespFut<'a, Req: 'a, Resp: 'a> = impl Future<Output = Result<Resp, RpcError>> + 'a;
68 changes: 21 additions & 47 deletions tarpc/src/client/stub/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,25 @@ mod round_robin {
context,
};
use cycle::AtomicCycle;
use futures::prelude::*;

impl<Stub> stub::Stub for RoundRobin<Stub>
where
Stub: stub::Stub,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
type RespFut<'a> = RespFut<'a, Stub>
where Self: 'a;

fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}

type RespFut<'a, Stub: stub::Stub + 'a> =
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;

/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
Expand All @@ -50,16 +45,6 @@ mod round_robin {
stubs: AtomicCycle::new(stubs),
}
}

async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Stub::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request_name, request).await
}
}

mod cycle {
Expand Down Expand Up @@ -118,36 +103,36 @@ mod consistent_hash {
client::{stub, RpcError},
context,
};
use futures::prelude::*;
use std::{
collections::hash_map::RandomState,
hash::{BuildHasher, Hash, Hasher},
num::TryFromIntError,
};

impl<Stub> stub::Stub for ConsistentHash<Stub>
impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
type RespFut<'a> = RespFut<'a, Stub>
where Self: 'a;

fn call<'a>(
&'a self,
async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
Self::call(self, ctx, request_name, request)
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}
}

type RespFut<'a, Stub: stub::Stub + 'a> =
impl Future<Output = Result<Stub::Resp, RpcError>> + 'a;

/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct ConsistentHash<Stub, S = RandomState> {
Expand Down Expand Up @@ -188,20 +173,6 @@ mod consistent_hash {
})
}

async fn call(
&self,
ctx: context::Context,
request_name: &'static str,
request: Stub::Req,
) -> Result<Stub::Resp, RpcError> {
let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
"invariant broken: stubs_len is not larger than a usize, \
so the hash modulo stubs_len should always fit in a usize",
);
let next = &self.stubs[index];
next.call(ctx, request_name, request).await
}

fn hash_request(&self, req: &Stub::Req) -> u64 {
let mut hasher = self.hasher.build_hasher();
req.hash(&mut hasher);
Expand All @@ -212,7 +183,10 @@ mod consistent_hash {
#[cfg(test)]
mod tests {
use super::ConsistentHash;
use crate::{client::stub::mock::Mock, context};
use crate::{
client::stub::{mock::Mock, Stub},
context,
};
use std::{
collections::HashMap,
hash::{BuildHasher, Hash, Hasher},
Expand All @@ -221,7 +195,7 @@ mod consistent_hash {

#[tokio::test]
async fn test() -> anyhow::Result<()> {
let stub = ConsistentHash::with_hasher(
let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
vec![
// For easier reading of the assertions made in this test, each Mock's response
// value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
Expand Down
31 changes: 13 additions & 18 deletions tarpc/src/client/stub/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
client::{stub::Stub, RpcError},
context, ServerError,
};
use futures::future;
use std::{collections::HashMap, hash::Hash, io};

/// A mock stub that returns user-specified responses.
Expand All @@ -29,26 +28,22 @@ where
{
type Req = Req;
type Resp = Resp;
type RespFut<'a> = future::Ready<Result<Resp, RpcError>>
where Self: 'a;

fn call<'a>(
&'a self,
async fn call(
&self,
_: context::Context,
_: &'static str,
request: Self::Req,
) -> Self::RespFut<'a> {
future::ready(
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
}),
)
) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}
Loading

0 comments on commit 6cf18a1

Please sign in to comment.