Skip to content

Commit

Permalink
refactor(client): remove MaxSlots limit (#1377)
Browse files Browse the repository at this point in the history
* refactor(client): remove MaxSlots limit

* fix wasm build
  • Loading branch information
niklasad1 authored May 24, 2024
1 parent a6d7a53 commit e2a0c9f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 133 deletions.
20 changes: 4 additions & 16 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,6 @@ impl<L> HttpClientBuilder<L> {
self
}

/// Set max concurrent requests.
pub fn max_concurrent_requests(mut self, max: usize) -> Self {
self.max_concurrent_requests = max;
self
}

/// Force to use the rustls native certificate store.
///
/// Since multiple certificate stores can be optionally enabled, this option will
Expand Down Expand Up @@ -198,7 +192,6 @@ where
let Self {
max_request_size,
max_response_size,
max_concurrent_requests,
request_timeout,
certificate_store,
id_kind,
Expand All @@ -220,11 +213,7 @@ where
.build(target)
.map_err(|e| Error::Transport(e.into()))?;

Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(max_concurrent_requests, id_kind)),
request_timeout,
})
Ok(HttpClient { transport, id_manager: Arc::new(RequestIdManager::new(id_kind)), request_timeout })
}
}

Expand Down Expand Up @@ -303,8 +292,7 @@ where
R: DeserializeOwned,
Params: ToRpcParams + Send,
{
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let id = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let request = RequestSer::borrowed(&id, &method, params.as_deref());
Expand Down Expand Up @@ -340,8 +328,8 @@ where
R: DeserializeOwned + fmt::Debug + 'a,
{
let batch = batch.build()?;
let guard = self.id_manager.next_request_id()?;
let id_range = generate_batch_id_range(&guard, batch.len() as u64)?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;

let mut batch_request = Vec::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
Expand Down
17 changes: 8 additions & 9 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ impl ClientBuilder {
to_back: to_back.clone(),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
id_manager: RequestIdManager::new(self.id_kind),
max_log_length: self.max_log_length,
on_exit: Some(client_dropped_tx),
}
Expand Down Expand Up @@ -403,7 +403,7 @@ impl ClientBuilder {
to_back: to_back.clone(),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
id_manager: RequestIdManager::new(self.id_kind),
max_log_length: self.max_log_length,
on_exit: Some(client_dropped_tx),
}
Expand Down Expand Up @@ -479,7 +479,7 @@ impl ClientT for Client {
Params: ToRpcParams + Send,
{
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id()?;
let _req_id = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;
let notif = NotificationSer::borrowed(&method, params.as_deref());

Expand All @@ -505,8 +505,7 @@ impl ClientT for Client {
Params: ToRpcParams + Send,
{
let (send_back_tx, send_back_rx) = oneshot::channel();
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let id = self.id_manager.next_request_id();

let params = params.to_rpc_params()?;
let raw =
Expand Down Expand Up @@ -540,8 +539,8 @@ impl ClientT for Client {
R: DeserializeOwned,
{
let batch = batch.build()?;
let guard = self.id_manager.next_request_id()?;
let id_range = generate_batch_id_range(&guard, batch.len() as u64)?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;

let mut batches = Vec::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
Expand Down Expand Up @@ -621,8 +620,8 @@ impl SubscriptionClientT for Client {
return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into());
}

let guard = self.id_manager.next_request_two_ids()?;
let (id_sub, id_unsub) = guard.inner();
let id_sub = self.id_manager.next_request_id();
let id_unsub = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let raw = serde_json::to_string(&RequestSer::borrowed(&id_sub, &subscribe_method, params.as_deref()))
Expand Down
3 changes: 0 additions & 3 deletions core/src/client/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ pub enum Error {
/// Request timeout
#[error("Request timeout")]
RequestTimeout,
/// Max number of request slots exceeded.
#[error("Max concurrent requests exceeded")]
MaxSlotsExceeded,
/// Custom error.
#[error("Custom error: {0}")]
Custom(String),
Expand Down
73 changes: 6 additions & 67 deletions core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,6 @@ impl<Notif> Drop for Subscription<Notif> {
#[derive(Debug)]
/// Keep track of request IDs.
pub struct RequestIdManager {
// Current pending requests.
current_pending: Arc<()>,
/// Max concurrent pending requests allowed.
max_concurrent_requests: usize,
/// Get the next request ID.
current_id: CurrentId,
/// Request ID type.
Expand All @@ -467,38 +463,15 @@ pub struct RequestIdManager {

impl RequestIdManager {
/// Create a new `RequestIdGuard` with the provided concurrency limit.
pub fn new(limit: usize, id_kind: IdKind) -> Self {
Self { current_pending: Arc::new(()), max_concurrent_requests: limit, current_id: CurrentId::new(), id_kind }
}

fn get_slot(&self) -> Result<Arc<()>, Error> {
// Strong count is 1 at start, so that's why we use `>` and not `>=`.
if Arc::strong_count(&self.current_pending) > self.max_concurrent_requests {
Err(Error::MaxSlotsExceeded)
} else {
Ok(self.current_pending.clone())
}
pub fn new(id_kind: IdKind) -> Self {
Self { current_id: CurrentId::new(), id_kind }
}

/// Attempts to get the next request ID.
///
/// Fails if request limit has been exceeded.
pub fn next_request_id(&self) -> Result<RequestIdGuard<Id<'static>>, Error> {
let rc = self.get_slot()?;
let id = self.id_kind.into_id(self.current_id.next());

Ok(RequestIdGuard { _rc: rc, id })
}

/// Attempts to get fetch two ids (used for subscriptions) but only
/// occupy one slot in the request guard.
///
/// Fails if request limit has been exceeded.
pub fn next_request_two_ids(&self) -> Result<RequestIdGuard<(Id<'static>, Id<'static>)>, Error> {
let rc = self.get_slot()?;
let id1 = self.id_kind.into_id(self.current_id.next());
let id2 = self.id_kind.into_id(self.current_id.next());
Ok(RequestIdGuard { _rc: rc, id: (id1, id2) })
pub fn next_request_id(&self) -> Id<'static> {
self.id_kind.into_id(self.current_id.next())
}

/// Get a handle to the `IdKind`.
Expand All @@ -507,21 +480,6 @@ impl RequestIdManager {
}
}

/// Reference counted request ID.
#[derive(Debug)]
pub struct RequestIdGuard<T: Clone> {
id: T,
/// Reference count decreased when dropped.
_rc: Arc<()>,
}

impl<T: Clone> RequestIdGuard<T> {
/// Get the actual ID or IDs.
pub fn inner(&self) -> T {
self.id.clone()
}
}

/// What certificate store to use
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
Expand Down Expand Up @@ -568,8 +526,8 @@ impl CurrentId {
}

/// Generate a range of IDs to be used in a batch request.
pub fn generate_batch_id_range(guard: &RequestIdGuard<Id>, len: u64) -> Result<Range<u64>, Error> {
let id_start = guard.inner().try_parse_inner_as_number()?;
pub fn generate_batch_id_range(id: Id, len: u64) -> Result<Range<u64>, Error> {
let id_start = id.try_parse_inner_as_number()?;
let id_end = id_start
.checked_add(len)
.ok_or_else(|| Error::Custom("BatchID range wrapped; restart the client or try again later".to_string()))?;
Expand Down Expand Up @@ -704,22 +662,3 @@ fn subscription_channel(max_buf_size: usize) -> (SubscriptionSender, Subscriptio

(SubscriptionSender { inner: tx, lagged: lagged_tx }, SubscriptionReceiver { inner: rx, lagged: lagged_rx })
}

#[cfg(test)]
mod tests {
use super::{IdKind, RequestIdManager};

#[test]
fn request_id_guard_works() {
let manager = RequestIdManager::new(2, IdKind::Number);
let _first = manager.next_request_id().unwrap();

{
let _second = manager.next_request_two_ids().unwrap();
assert!(manager.next_request_id().is_err());
// second dropped here.
}

assert!(manager.next_request_id().is_ok());
}
}
38 changes: 0 additions & 38 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,6 @@ async fn http_method_call_str_id_works() {
assert_eq!(&response, "hello");
}

#[tokio::test]
async fn http_concurrent_method_call_limits_works() {
init_logger();

let server_addr = server().await;
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(1).build(&uri).unwrap();

let (first, second) = tokio::join!(
client.request::<String, ArrayParams>("say_hello", rpc_params!()),
client.request::<String, ArrayParams>("say_hello", rpc_params![]),
);

assert!(first.is_ok());
assert!(matches!(second, Err(Error::MaxSlotsExceeded)));
}

#[tokio::test]
async fn ws_subscription_several_clients() {
init_logger();
Expand Down Expand Up @@ -418,27 +401,6 @@ async fn ws_making_more_requests_than_allowed_should_not_deadlock() {
}
}

#[tokio::test]
async fn http_making_more_requests_than_allowed_should_not_deadlock() {
init_logger();

let server_addr = server().await;
let server_url = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(2).build(&server_url).unwrap();
let client = Arc::new(client);

let mut requests = Vec::new();

for _ in 0..6 {
let c = client.clone();
requests.push(tokio::spawn(async move { c.request::<String, ArrayParams>("say_hello", rpc_params![]).await }));
}

for req in requests {
let _ = req.await.unwrap();
}
}

#[tokio::test]
async fn https_works() {
init_logger();
Expand Down

0 comments on commit e2a0c9f

Please sign in to comment.