diff --git a/mountpoint-s3-client/src/lib.rs b/mountpoint-s3-client/src/lib.rs index 68df66785..1bab3b8f5 100644 --- a/mountpoint-s3-client/src/lib.rs +++ b/mountpoint-s3-client/src/lib.rs @@ -61,7 +61,9 @@ pub mod error_metadata; pub use object_client::{ObjectClient, PutObjectRequest}; -pub use s3_crt_client::{get_object::S3GetObjectResponse, put_object::S3PutObjectRequest, S3CrtClient, S3RequestError}; +pub use s3_crt_client::{ + get_object::S3GetObjectResponse, put_object::S3PutObjectRequest, OnTelemetry, S3CrtClient, S3RequestError, +}; /// Configuration for the S3 client pub mod config { diff --git a/mountpoint-s3-client/src/s3_crt_client.rs b/mountpoint-s3-client/src/s3_crt_client.rs index e0bc343dc..43a720b44 100644 --- a/mountpoint-s3-client/src/s3_crt_client.rs +++ b/mountpoint-s3-client/src/s3_crt_client.rs @@ -102,6 +102,7 @@ pub struct S3ClientConfig { read_backpressure: bool, initial_read_window: usize, network_interface_names: Vec, + telemetry_callback: Option>, } impl Default for S3ClientConfig { @@ -120,6 +121,7 @@ impl Default for S3ClientConfig { read_backpressure: false, initial_read_window: DEFAULT_PART_SIZE, network_interface_names: vec![], + telemetry_callback: None, } } } @@ -221,6 +223,13 @@ impl S3ClientConfig { self.network_interface_names = network_interface_names; self } + + /// Set a custom telemetry callback handler + #[must_use = "S3ClientConfig follows a builder pattern"] + pub fn telemetry_callback(mut self, telemetry_callback: Arc) -> Self { + self.telemetry_callback = Some(telemetry_callback); + self + } } /// Authentication configuration for the CRT-based S3 client @@ -288,6 +297,7 @@ struct S3CrtClientInner { bucket_owner: Option, credentials_provider: Option, host_resolver: HostResolver, + telemetry_callback: Option>, } impl S3CrtClientInner { @@ -422,6 +432,7 @@ impl S3CrtClientInner { bucket_owner: config.bucket_owner, credentials_provider: Some(credentials_provider), host_resolver, + telemetry_callback: config.telemetry_callback, }) } @@ -551,6 +562,7 @@ impl S3CrtClientInner { let endpoint = options.get_endpoint().expect("S3Message always has an endpoint"); let hostname = endpoint.host_name().to_str().unwrap().to_owned(); let host_resolver = self.host_resolver.clone(); + let telemetry_callback = self.telemetry_callback.clone(); let start_time = Instant::now(); let first_body_part = Arc::new(AtomicBool::new(true)); @@ -595,6 +607,10 @@ impl S3CrtClientInner { } else if request_canceled { metrics::counter!("s3.requests.canceled", "op" => op, "type" => request_type).increment(1); } + + if let Some(telemetry_callback) = &telemetry_callback { + telemetry_callback.on_telemetry(metrics); + } }) .on_headers(move |headers, response_status| { (on_headers)(headers, response_status); @@ -1370,6 +1386,11 @@ impl ObjectClient for S3CrtClient { } } +/// Custom handling of telemetry events +pub trait OnTelemetry: std::fmt::Debug + Send + Sync { + fn on_telemetry(&self, request_metrics: &RequestMetrics); +} + #[cfg(test)] mod tests { use mountpoint_s3_crt::common::error::Error; diff --git a/mountpoint-s3-client/tests/common/mod.rs b/mountpoint-s3-client/tests/common/mod.rs index d1c97dbd5..8de3fb095 100644 --- a/mountpoint-s3-client/tests/common/mod.rs +++ b/mountpoint-s3-client/tests/common/mod.rs @@ -11,13 +11,14 @@ use bytes::Bytes; use futures::{pin_mut, Stream, StreamExt}; use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig}; use mountpoint_s3_client::types::{ClientBackpressureHandle, GetObjectResponse}; -use mountpoint_s3_client::S3CrtClient; +use mountpoint_s3_client::{OnTelemetry, S3CrtClient}; use mountpoint_s3_crt::common::allocator::Allocator; use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter; use mountpoint_s3_crt::common::uri::Uri; use rand::rngs::OsRng; use rand::RngCore; use std::ops::Range; +use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt as _; use tracing_subscriber::{EnvFilter, Layer}; @@ -87,6 +88,13 @@ pub fn get_test_backpressure_client(initial_read_window: usize, part_size: Optio S3CrtClient::new(config).expect("could not create test client") } +pub fn get_test_client_with_custom_telemetry(telemetry_callback: Arc) -> S3CrtClient { + let config = S3ClientConfig::new() + .endpoint_config(get_test_endpoint_config()) + .telemetry_callback(telemetry_callback); + S3CrtClient::new(config).expect("could not create test client") +} + pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) { let bucket = get_test_bucket(); let prefix = get_unique_test_prefix(test_name); diff --git a/mountpoint-s3-client/tests/metrics.rs b/mountpoint-s3-client/tests/metrics.rs index eabf70d99..57ccd233c 100644 --- a/mountpoint-s3-client/tests/metrics.rs +++ b/mountpoint-s3-client/tests/metrics.rs @@ -21,7 +21,8 @@ use metrics::{ }; use mountpoint_s3_client::error::ObjectClientError; use mountpoint_s3_client::types::{GetObjectParams, HeadObjectParams}; -use mountpoint_s3_client::{ObjectClient, S3CrtClient, S3RequestError}; +use mountpoint_s3_client::{ObjectClient, OnTelemetry, S3CrtClient, S3RequestError}; +use mountpoint_s3_crt::s3::client::RequestMetrics; use regex::Regex; use rusty_fork::rusty_fork_test; use tracing::Level; @@ -280,3 +281,73 @@ rusty_fork_test! { runtime.block_on(test_head_object_403()); } } + +async fn test_custom_telemetry_callback() { + let sdk_client = get_test_sdk_client().await; + let (bucket, prefix) = get_test_bucket_and_prefix("test_custom_telemetry_callback"); + + let key = format!("{prefix}/test"); + let body = vec![0x42; 100]; + sdk_client + .put_object() + .bucket(&bucket) + .key(&key) + .body(ByteStream::from(body.clone())) + .send() + .await + .unwrap(); + + let recorder = TestRecorder::default(); + metrics::set_global_recorder(recorder.clone()).unwrap(); + + #[derive(Debug)] + struct CustomOnTelemetry { + metric_name: String, + } + + impl OnTelemetry for CustomOnTelemetry { + fn on_telemetry(&self, request_metrics: &RequestMetrics) { + metrics::counter!(self.metric_name.clone()).absolute(request_metrics.total_duration().as_micros() as u64); + } + } + + let request_duration_metric_name = "request_duration_us"; + + let custom_telemetry_callback = CustomOnTelemetry { + metric_name: String::from(request_duration_metric_name), + }; + + let client = get_test_client_with_custom_telemetry(Arc::new(custom_telemetry_callback)); + let result = client + .get_object(&bucket, &key, &GetObjectParams::new()) + .await + .expect("get_object should succeed"); + let result = result + .map_ok(|(_offset, bytes)| bytes.len()) + .try_fold(0, |a, b| async move { Ok(a + b) }) + .await + .expect("get_object should succeed"); + assert_eq!(result, body.len()); + + let metrics = recorder.metrics.lock().unwrap().clone(); + + let (_, request_duration_us) = metrics + .get(request_duration_metric_name, None, None) + .expect("The custom metric should be emitted"); + + let Metric::Counter(request_duration_us) = request_duration_us else { + panic!("Expected a counter metric") + }; + assert!( + *request_duration_us.lock().unwrap() > 0, + "The request duration should be more than 0 microseconds" + ); +} + +rusty_fork_test! { + #[test] + fn custom_telemetry_callbacks_are_called() { + let runtime = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); + runtime.block_on(test_custom_telemetry_callback()); + } +}