Skip to content

Commit

Permalink
Introduce RemoteResult
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Passaro <[email protected]>
  • Loading branch information
passaro committed Dec 16, 2024
1 parent f9748e2 commit d41531b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 44 deletions.
142 changes: 110 additions & 32 deletions mountpoint-s3/src/async_util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::fmt::Debug;
use std::future::Future;
use std::{fmt::Debug, future::Future};

use futures::future::{BoxFuture, FutureExt};
use futures::task::{Spawn, SpawnError};
use async_channel::{Receiver, Sender};
use futures::task::{Spawn, SpawnError, SpawnExt};

/// Type-erasure for a [Spawn] implementation.
///
/// Wraps any spawner that is also `Send + Sync` in a boxed trait object, so that
/// code holding a [BoxRuntime] does not need to be generic over the concrete spawner type.
pub struct BoxRuntime(Box<dyn Spawn + Send + Sync>);
Expand All @@ -23,51 +22,130 @@ impl BoxRuntime {
pub fn new(runtime: impl Spawn + Sync + Send + 'static) -> Self {
BoxRuntime(Box::new(runtime))
}

/// Spawns a task that polls the given future to completion and returns
/// a [RemoteResult] through which its output can be retrieved.
///
/// Fails with a [SpawnError] if the task could not be spawned.
pub fn spawn_with_result<T, E, F>(&self, future: F) -> Result<RemoteResult<T, E>, SpawnError>
where
    T: Send + 'static,
    E: Send + 'static,
    F: Future<Output = Result<T, E>> + Send + 'static,
{
    let (result_sender, remote_result) = result_channel();
    let task = async move {
        // The flag returned by `send` (whether the send succeeded) is not used here.
        result_sender.send(future.await).await;
    };
    self.spawn(task)?;
    Ok(remote_result)
}
}

/// Builds an async channel of capacity 1 whose receiving half is a [RemoteResult].
///
/// Returns the [ResultSender] used to deliver the result, paired with the
/// [RemoteResult] that will receive it. The [RemoteResult] starts with no cached value.
pub fn result_channel<T, E>() -> (ResultSender<T, E>, RemoteResult<T, E>) {
    let (tx, rx) = async_channel::bounded(1);
    let result_sender = ResultSender { sender: tx };
    let remote_result = RemoteResult {
        receiver: rx,
        value: None,
    };
    (result_sender, remote_result)
}

/// Holds a value lazily initialized when awaiting a future.
pub struct Lazy<T, E> {
future: Option<BoxFuture<'static, Result<T, E>>>,
/// Holds the result of a spawned task.
#[derive(Debug)]
pub struct RemoteResult<T, E> {
receiver: Receiver<Result<T, E>>,
value: Option<T>,
}

impl<T, E> Lazy<T, E> {
pub fn new(f: impl Future<Output = Result<T, E>> + Send + 'static) -> Self {
Self {
future: Some(f.boxed()),
value: None,
}
/// Sender side of a [RemoteResult].
pub struct ResultSender<T, E> {
sender: Sender<Result<T, E>>,
}

impl<T, E> ResultSender<T, E> {
    /// Sends `value` through the underlying channel, consuming the sender.
    ///
    /// Returns `true` if the channel accepted the value and `false` if the send failed.
    pub async fn send(self, value: Result<T, E>) -> bool {
        match self.sender.send(value).await {
            Ok(()) => true,
            Err(_) => false,
        }
    }
}

async fn force(&mut self) -> Result<(), E> {
if let Some(f) = self.future.take() {
self.value = Some(f.await?);
impl<T, E> RemoteResult<T, E> {
async fn receive(&mut self) -> Result<&mut Option<T>, E> {
if self.value.is_none() {
if let Ok(value) = self.receiver.recv().await {
self.value = Some(value?);
}
}
Ok(())
Ok(&mut self.value)
}

pub async fn get_mut(&mut self) -> Result<Option<&mut T>, E> {
self.force().await?;
Ok(self.value.as_mut())
Ok(self.receive().await?.as_mut())
}

pub async fn into_inner(mut self) -> Result<Option<T>, E> {
self.force().await?;
Ok(self.value.take())
Ok(self.receive().await?.take())
}
}

impl<T, E> Debug for Lazy<T, E>
where
T: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut s = f.debug_struct("Lazy");
if let Some(value) = &self.value {
s.field("value", value);
} else {
s.field("future", &"<pending>");
impl<T, E> Drop for RemoteResult<T, E> {
fn drop(&mut self) {
// Blocks to wait for the result and then drop it.
// Ignore the error if the sender has already been dropped.
_ = self.receiver.recv_blocking();
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use futures::executor::{block_on, ThreadPool};
use test_case::test_case;

use super::{result_channel, BoxRuntime};

// A result sent through the channel must come back unchanged from `into_inner`,
// for both the success and the error case.
#[test_case(Ok(42))]
#[test_case(Err("error"))]
fn test_into_inner(result: Result<i32, &'static str>) {
    let (sender, receiver) = result_channel();
    block_on(sender.send(result));

    // `into_inner` yields `Result<Option<_>, _>`; transpose so the missing-value
    // case (`None`) is what `unwrap` panics on.
    let received = block_on(receiver.into_inner());
    let received = received.transpose().unwrap();
    assert_eq!(received, result);
}

#[test_case(Ok(42))]
#[test_case(Err("error"))]
fn test_get_mut(result: Result<i32, &'static str>) {
let expected = result;
let (sender, mut receiver) = result_channel();
block_on(sender.send(result));

let result = block_on(receiver.get_mut()).transpose().unwrap();
match expected {
Ok(expected_value) => assert!(matches!(result, Ok(value) if *value == expected_value)),
Err(expected_error) => assert!(matches!(result, Err(error) if *error == *expected_error)),
}
s.finish()
}

#[test]
fn test_drop() {
    let runtime = BoxRuntime::new(ThreadPool::new().unwrap());

    // Sets the shared flag when it is dropped.
    struct Dropping(Arc<AtomicBool>);

    impl Drop for Dropping {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_dropped = Arc::new(AtomicBool::new(false));
    let flag = was_dropped.clone();

    let remote_result = runtime
        .spawn_with_result(async move { Ok::<_, &'static str>(Dropping(flag)) })
        .unwrap();

    // Dropping the remote result is expected to wait for the task's output and drop it,
    // so the flag must be set by the time `drop` returns.
    drop(remote_result);

    assert!(was_dropped.load(Ordering::SeqCst));
}
}
12 changes: 6 additions & 6 deletions mountpoint-s3/src/upload/atomic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::fmt::Debug;

use futures::task::SpawnExt as _;
use mountpoint_s3_client::checksums::{crc32c, crc32c_from_base64, Crc32c};
use mountpoint_s3_client::error::{ObjectClientError, PutObjectError};
use mountpoint_s3_client::types::{
Expand All @@ -9,7 +8,7 @@ use mountpoint_s3_client::types::{
use mountpoint_s3_client::{ObjectClient, PutObjectRequest};
use tracing::error;

use crate::async_util::Lazy;
use crate::async_util::RemoteResult;
use crate::checksums::combine_checksums;
use crate::ServerSideEncryption;

Expand All @@ -21,7 +20,7 @@ const MAX_S3_MULTIPART_UPLOAD_PARTS: usize = 10000;
///
/// Wraps a PutObject request and enforces sequential writes.
pub struct UploadRequest<Client: ObjectClient> {
request: Lazy<Client::PutObjectRequest, ObjectClientError<PutObjectError, Client::ClientError>>,
request: RemoteResult<Client::PutObjectRequest, ObjectClientError<PutObjectError, Client::ClientError>>,
bucket: String,
key: String,
next_request_offset: u64,
Expand Down Expand Up @@ -67,17 +66,17 @@ where
let put_bucket = bucket.to_owned();
let put_key = key.to_owned();
let client = uploader.client.clone();
let request_handle = uploader
let request = uploader
.runtime
.spawn_with_handle(async move { client.put_object(&put_bucket, &put_key, &params).await })
.spawn_with_result(async move { client.put_object(&put_bucket, &put_key, &params).await })
.unwrap();
let maximum_upload_size = uploader
.client
.write_part_size()
.map(|ps| ps.saturating_mul(MAX_S3_MULTIPART_UPLOAD_PARTS));

Ok(UploadRequest {
request: Lazy::new(request_handle),
request,
bucket: bucket.to_owned(),
key: key.to_owned(),
next_request_offset: 0,
Expand Down Expand Up @@ -369,6 +368,7 @@ mod tests {
for i in 0..successful_writes {
let offset = i * write_size;
request.write(offset as i64, &data).await.expect("object should fit");
assert!(client.is_upload_in_progress(key));
}

let offset = successful_writes * write_size;
Expand Down
11 changes: 5 additions & 6 deletions mountpoint-s3/src/upload/incremental.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::fmt::Debug;
use std::mem;

use async_channel::{bounded, unbounded, Receiver, Sender};
use futures::channel::oneshot;
use futures::future::RemoteHandle;
use futures::task::SpawnExt as _;
use mountpoint_s3_client::error::{ObjectClientError, PutObjectError};
Expand All @@ -14,7 +13,7 @@ use mountpoint_s3_client::types::{
use mountpoint_s3_client::ObjectClient;
use tracing::{debug_span, trace, Instrument};

use crate::async_util::{BoxRuntime, Lazy};
use crate::async_util::{result_channel, BoxRuntime, RemoteResult};
use crate::mem_limiter::{BufferArea, MemoryLimiter};
use crate::sync::Arc;
use crate::ServerSideEncryption;
Expand Down Expand Up @@ -147,7 +146,7 @@ struct AppendUploadQueue<Client: ObjectClient> {
mem_limiter: Arc<MemoryLimiter<Client>>,
_task_handle: RemoteHandle<()>,
/// Algorithm used to compute checksums. Lazily initialized.
checksum_algorithm: Lazy<Option<ChecksumAlgorithm>, UploadError<Client::ClientError>>,
checksum_algorithm: RemoteResult<Option<ChecksumAlgorithm>, UploadError<Client::ClientError>>,
/// Stores the last successful result to return in [join].
last_known_result: Option<PutObjectResult>,
/// Tracks the requests pushed to the queue but still pending a response.
Expand All @@ -167,15 +166,15 @@ where
let span = debug_span!("append", key = params.key, initial_offset = params.initial_offset);
let (request_sender, request_receiver) = bounded(params.capacity);
let (response_sender, response_receiver) = unbounded();
let (checksum_algorithm_sender, checksum_algorithm_receiver) = oneshot::channel();
let (checksum_algorithm_sender, checksum_algorithm) = result_channel();

// Create a task for reading data out of the upload queue and create S3 requests for them.
let task_handle = runtime
.spawn_with_handle(
async move {
let checksum_algorithm = get_checksum_algorithm(&client, &params).await;
let is_error = checksum_algorithm.is_err();
if checksum_algorithm_sender.send(checksum_algorithm).is_err() || is_error {
if !checksum_algorithm_sender.send(checksum_algorithm).await || is_error {
return;
}

Expand All @@ -191,7 +190,7 @@ where
last_known_result: None,
requests_in_queue: 0,
mem_limiter,
checksum_algorithm: Lazy::new(async move { checksum_algorithm_receiver.await.unwrap() }),
checksum_algorithm,
_task_handle: task_handle,
}
}
Expand Down

0 comments on commit d41531b

Please sign in to comment.