diff --git a/Cargo.lock b/Cargo.lock index f9cc9ec..5795057 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -924,6 +924,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "url", diff --git a/Cargo.toml b/Cargo.toml index f088761..7fc3538 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ crate-type = ["cdylib"] [dependencies] tokio = { version = "1.32.0", features = ["rt-multi-thread", "macros", "signal", "time"] } +tokio-util = { version = "0.7", default-features = false, features = ["io"] } bytes = "1.0" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src/error.rs b/src/error.rs index 88c3a81..5de1410 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,7 +23,7 @@ pub(crate) enum ErrorReason { Timeout } -fn backoff_duration_for_retry(retries: usize, meta: &ConfigMeta) -> Duration { +pub(crate) fn backoff_duration_for_retry(retries: usize, meta: &ConfigMeta) -> Duration { // We try to use the same settings as the object_store backoff but the implementation is // different so this is best effort. let mut backoff = backoff::ExponentialBackoff { diff --git a/src/lib.rs b/src/lib.rs index 268eb02..1460699 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,14 @@ +use futures_util::{StreamExt, TryStreamExt}; use futures_util::stream::BoxStream; -use futures_util::StreamExt; use object_store::{RetryConfig, ObjectMeta}; use once_cell::sync::OnceCell; -use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncWriteExt, AsyncReadExt}; use tokio::runtime::Runtime; +use tokio_util::io::StreamReader; use std::collections::HashMap; use std::ffi::CString; use std::ffi::{c_char, c_void}; +use std::ops::Range; use std::sync::Arc; use std::time::{Duration, Instant}; use anyhow::anyhow; @@ -20,7 +22,7 @@ use moka::future::Cache; mod error; -use error::{extract_error_info, should_retry}; +use error::{extract_error_info, should_retry, should_retry_logic, backoff_duration_for_retry}; // Our global variables needed by our library at runtime. Note that we follow Rust's // safety rules here by making them immutable with write-exactly-once semantics using @@ -68,7 +70,8 @@ enum Request { Put(Path, &'static [u8], &'static Config, ResponseGuard), Delete(Path, &'static Config, ResponseGuard), List(Path, &'static Config, ListResponseGuard), - ListStream(Path, &'static Config, ListStreamResponseGuard) + ListStream(Path, &'static Config, ListStreamResponseGuard), + GetStream(Path, usize, &'static Config, GetStreamResponseGuard) } unsafe impl Send for Request {} @@ -300,18 +303,18 @@ pub struct Response { unsafe impl Send for Response {} -async fn multipart_get(slice: &mut [u8], path: &Path, client: &dyn ObjectStore) -> anyhow::Result { - let part_size: usize = static_config().multipart_get_part_size as usize; - let result = client.head(&path).await?; - if result.size > slice.len() { - return Err(anyhow!("Supplied buffer was too small")); +fn size_to_ranges(object_size: usize) -> Vec> { + if object_size == 0 { + return vec![]; } + let part_size: usize = static_config().multipart_get_part_size as usize; + // If the object size happens to be smaller than part_size, // then we will end up doing a single range get of the whole // object. - let mut parts = result.size / part_size; - if result.size % part_size != 0 { + let mut parts = object_size / part_size; + if object_size % part_size != 0 { parts += 1; } let mut part_ranges = Vec::with_capacity(parts); @@ -319,7 +322,18 @@ async fn multipart_get(slice: &mut [u8], path: &Path, client: &dyn ObjectStore) part_ranges.push((i*part_size)..((i+1)*part_size)); } // Last part which handles sizes not divisible by part_size - part_ranges.push(((parts-1)*part_size)..result.size); + part_ranges.push(((parts-1)*part_size)..object_size); + + return part_ranges; +} + +async fn multipart_get(slice: &mut [u8], path: &Path, client: &dyn ObjectStore) -> anyhow::Result { + let result = client.head(&path).await?; + if result.size > slice.len() { + return Err(anyhow!("Supplied buffer was too small")); + } + + let part_ranges = size_to_ranges(result.size); let result_vec = client.get_ranges(&path, &part_ranges).await?; let mut accum: usize = 0; @@ -448,6 +462,65 @@ async fn handle_list_stream(prefix: &Path, config: &Config) -> anyhow::Result anyhow::Result<(Box, usize)> { + let (client, config_meta) = clients() + .try_get_with(config.get_hash(), dyn_connect(config)).await + .map_err(|e| anyhow!(e))?; + + if size_hint > 0 && size_hint < static_config().multipart_get_threshold as usize { + // Perform a single get without the head request + let result = client.get(path).await?; + let full_size = result.meta.size; + let stream = result.into_stream().map_err(Into::into).boxed(); + let reader = StreamReader::new(stream); + return Ok((Box::new(GetStreamWrapper { reader }), full_size)); + } else { + // Perform head request and prefetch parts in parallel + let meta = client.head(&path).await?; + let part_ranges = size_to_ranges(meta.size); + + let state = ( + client, + path.clone(), + config_meta.clone() + ); + let stream = futures_util::stream::iter(part_ranges) + .scan(state, |state, range| { + let state = state.clone(); + async move { Some((state, range)) } + }) + .map(|((client, path, config_meta), range)| async move { + return tokio::spawn(async move { + let start_instant = Instant::now(); + let mut retries = 0; + 'retry: loop { + match client.get_range(&path, range.clone()).await.map_err(Into::into) { + Ok(bytes) => { + return Ok::<_, anyhow::Error>(bytes) + }, + Err(e) => { + if should_retry_logic(retries, &e, start_instant.elapsed(), &config_meta) { + let duration = backoff_duration_for_retry(retries, &config_meta); + + retries += 1; + tracing::info!("retrying error (reason: {:?}) after {:?}: {}", extract_error_info(&e).reason, duration, e); + tokio::time::sleep(duration).await; + continue 'retry; + } + } + } + } + }).await?; + }) + .buffered(16) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + .boxed(); + + let reader = StreamReader::new(stream); + return Ok((Box::new(GetStreamWrapper { reader }), meta.size)); + } +} + #[no_mangle] pub extern "C" fn start( config: StaticConfig, @@ -612,6 +685,27 @@ pub extern "C" fn start( } } } + Request::GetStream(path, size_hint, config, response) => { + 'retry: loop { + match handle_get_stream(&path, size_hint, config).await { + Ok((stream, full_size)) => { + response.success(stream, full_size); + return; + } + Err(e) => { + if let Some(duration) = should_retry(retries, &e, start_instant.elapsed(), config).await { + retries += 1; + tracing::info!("retrying error (reason: {:?}) after {:?}: {}", extract_error_info(&e).reason, duration, e); + tokio::time::sleep(duration).await; + continue 'retry; + } + tracing::warn!("{}", e); + response.into_error(e); + return; + } + } + } + } } } }).buffer_unordered(static_config().concurrency_limit as usize).for_each(|_| async {}).await; @@ -1013,6 +1107,218 @@ pub extern "C" fn next_list_stream_chunk( } } +#[repr(C)] +pub struct ReadResponse { + result: CResult, + length: usize, + eof: bool, + error_message: *mut c_char +} + +unsafe impl Send for ReadResponse {} + +// RAII Guard for a ListResponse that ensures the awaiting Julia task will be notified +// even if this is dropped on a panic. +pub struct ReadResponseGuard { + response: &'static mut ReadResponse, + handle: *const c_void +} + +impl NotifyGuard for ReadResponseGuard { + fn is_uninitialized(&self) -> bool { + self.response.result == CResult::Uninitialized + } + fn condition_handle(&self) -> *const c_void { + self.handle + } + fn set_error(&mut self, error: impl std::fmt::Display) { + self.response.result = CResult::Error; + self.response.length = 0; + self.response.eof = false; + let c_string = CString::new(format!("{}", error)).expect("should not have nulls"); + self.response.error_message = c_string.into_raw(); + } +} + +impl ReadResponseGuard { + unsafe fn new(response_ptr: *mut ReadResponse, handle: *const c_void) -> ReadResponseGuard { + let response = unsafe { &mut (*response_ptr) }; + response.result = CResult::Uninitialized; + + ReadResponseGuard { response, handle } + } + fn success(self, length: usize, eof: bool) { + self.response.result = CResult::Ok; + self.response.length = length; + self.response.eof = eof; + self.response.error_message = std::ptr::null_mut(); + } +} + +impl Drop for ReadResponseGuard { + fn drop(&mut self) { + self.notify_on_drop() + } +} + +unsafe impl Send for ReadResponseGuard {} + +pub struct GetStreamWrapper { + reader: StreamReader>, bytes::Bytes> +} + +#[no_mangle] +pub extern "C" fn destroy_get_stream( + stream: *mut GetStreamWrapper +) -> CResult { + let boxed = unsafe { Box::from_raw(stream) }; + drop(boxed); + CResult::Ok +} + +#[repr(C)] +pub struct GetStreamResponse { + result: CResult, + stream: *mut GetStreamWrapper, + object_size: u64, + error_message: *mut c_char +} + +unsafe impl Send for GetStreamResponse {} + +// RAII Guard for a ListResponse that ensures the awaiting Julia task will be notified +// even if this is dropped on a panic. +pub struct GetStreamResponseGuard { + response: &'static mut GetStreamResponse, + handle: *const c_void +} + +impl NotifyGuard for GetStreamResponseGuard { + fn is_uninitialized(&self) -> bool { + self.response.result == CResult::Uninitialized + } + fn condition_handle(&self) -> *const c_void { + self.handle + } + fn set_error(&mut self, error: impl std::fmt::Display) { + self.response.result = CResult::Error; + self.response.stream = std::ptr::null_mut(); + self.response.object_size = 0; + let c_string = CString::new(format!("{}", error)).expect("should not have nulls"); + self.response.error_message = c_string.into_raw(); + } +} + +impl GetStreamResponseGuard { + unsafe fn new(response_ptr: *mut GetStreamResponse, handle: *const c_void) -> GetStreamResponseGuard { + let response = unsafe { &mut (*response_ptr) }; + response.result = CResult::Uninitialized; + + GetStreamResponseGuard { response, handle } + } + fn success(self, stream: Box, object_size: usize) { + self.response.result = CResult::Ok; + self.response.stream = Box::into_raw(stream); + self.response.object_size = object_size as u64; + self.response.error_message = std::ptr::null_mut(); + } +} + +impl Drop for GetStreamResponseGuard { + fn drop(&mut self) { + self.notify_on_drop() + } +} + +unsafe impl Send for GetStreamResponseGuard {} + +#[no_mangle] +pub extern "C" fn get_stream( + path: *const c_char, + size_hint: usize, + config: *const Config, + response: *mut GetStreamResponse, + handle: *const c_void +) -> CResult { + let response = unsafe { GetStreamResponseGuard::new(response, handle) }; + let path = unsafe { std::ffi::CStr::from_ptr(path) }; + let path: Path = path.to_str().expect("invalid utf8").try_into().unwrap(); + let config = unsafe { & (*config) }; + + match SQ.get() { + Some(sq) => { + match sq.try_send(Request::GetStream(path, size_hint, config, response)) { + Ok(_) => CResult::Ok, + Err(async_channel::TrySendError::Full(_)) => { + CResult::Backoff + } + Err(async_channel::TrySendError::Closed(_)) => { + CResult::Error + } + } + } + None => { + return CResult::Error; + } + } +} + +#[no_mangle] +pub extern "C" fn read_get_stream( + stream: *mut GetStreamWrapper, + buffer: *mut u8, + size: usize, + amount: usize, + response: *mut ReadResponse, + handle: *const c_void +) -> CResult { + let response = unsafe { ReadResponseGuard::new(response, handle) }; + let mut slice = unsafe { std::slice::from_raw_parts_mut(buffer, size) }; + let wrapper = match unsafe { stream.as_mut() } { + Some(w) => w, + None => { + tracing::error!("null stream pointer"); + return CResult::Error; + } + }; + + match RT.get() { + Some(runtime) => { + runtime.spawn(async move { + let read_op = async { + let amount_to_read = size.min(amount); + let mut bytes_read = 0; + while amount_to_read > bytes_read { + let n = wrapper.reader.read_buf(&mut slice).await?; + + if n == 0 { + return Ok((bytes_read, true)) + } else { + bytes_read += n; + } + } + + Ok::<_, anyhow::Error>((bytes_read, false)) + }; + + match read_op.await { + Ok((bytes_read, eof)) => { + response.success(bytes_read, eof); + }, + Err(e) => { + tracing::warn!("{}", e); + response.into_error(e); + } + } + }); + CResult::Ok + } + None => { + return CResult::Error; + } + } +} + #[no_mangle] pub extern "C" fn destroy_cstring(string: *mut c_char) -> CResult { let string = unsafe { std::ffi::CString::from_raw(string) };