diff --git a/mountpoint-s3/examples/prefetch_benchmark.rs b/mountpoint-s3/examples/prefetch_benchmark.rs
index 737eb3d2e..fb1b93c2b 100644
--- a/mountpoint-s3/examples/prefetch_benchmark.rs
+++ b/mountpoint-s3/examples/prefetch_benchmark.rs
@@ -1,13 +1,15 @@
+use std::str::FromStr;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
+use std::thread;
 use std::time::Instant;
 
 use clap::{Arg, Command};
-use futures::executor::{block_on, ThreadPool};
+use futures::executor::block_on;
 use mountpoint_s3::prefetch::{default_prefetch, Prefetch, PrefetchResult};
 use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig};
 use mountpoint_s3_client::types::ETag;
-use mountpoint_s3_client::S3CrtClient;
+use mountpoint_s3_client::{ObjectClient, S3CrtClient};
 use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter;
 use tracing_subscriber::fmt::Subscriber;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -32,7 +34,6 @@ fn main() {
         .about("Download a single key from S3 and ignore its contents")
         .arg(Arg::new("bucket").required(true))
         .arg(Arg::new("key").required(true))
-        .arg(Arg::new("size").required(true))
         .arg(
             Arg::new("throughput-target-gbps")
                 .long("throughput-target-gbps")
@@ -43,33 +44,46 @@ fn main() {
                 .long("part-size")
                 .help("Part size for multi-part GET and PUT"),
         )
+        .arg(Arg::new("read-size").long("read-size").help("Size of read requests"))
         .arg(
             Arg::new("iterations")
                 .long("iterations")
                 .help("Number of times to download"),
         )
+        .arg(
+            Arg::new("downloads")
+                .long("downloads")
+                .help("Number of concurrent downloads"),
+        )
         .arg(Arg::new("region").long("region").default_value("us-east-1"))
         .get_matches();
 
     let bucket = matches.get_one::<String>("bucket").unwrap();
     let key = matches.get_one::<String>("key").unwrap();
-    let size = matches
-        .get_one::<String>("size")
-        .unwrap()
-        .parse::<u64>()
-        .expect("size must be u64");
     let throughput_target_gbps = matches
         .get_one::<String>("throughput-target-gbps")
         .map(|s| s.parse::<f64>().expect("throughput target must be an f64"));
     let part_size = matches
         .get_one::<String>("part-size")
        .map(|s| s.parse::<usize>().expect("part size must be a usize"));
+    let read_size = matches
+        .get_one::<String>("read-size")
+        .map(|s| s.parse::<usize>().expect("read size must be a usize"))
+        .unwrap_or(128 * 1024);
     let iterations = matches
         .get_one::<String>("iterations")
         .map(|s| s.parse::<u32>().expect("iterations must be a number"));
+    let downloads = matches
+        .get_one::<String>("downloads")
+        .map(|s| s.parse::<usize>().expect("downloads must be a number"))
+        .unwrap_or(1);
     let region = matches.get_one::<String>("region").unwrap();
 
-    let mut config = S3ClientConfig::new().endpoint_config(EndpointConfig::new(region));
+    let initial_read_window_size = 1024 * 1024 + 128 * 1024;
+    let mut config = S3ClientConfig::new()
+        .endpoint_config(EndpointConfig::new(region))
+        .read_backpressure(true)
+        .initial_read_window(initial_read_window_size);
     if let Some(throughput_target_gbps) = throughput_target_gbps {
         config = config.throughput_target_gbps(throughput_target_gbps);
     }
@@ -78,28 +92,38 @@ fn main() {
     }
     let client = Arc::new(S3CrtClient::new(config).expect("couldn't create client"));
 
+    let head_object_result = block_on(client.head_object(bucket, key)).expect("HeadObject failed");
+    let size = head_object_result.object.size;
+    let etag = ETag::from_str(&head_object_result.object.etag).unwrap();
+
     for i in 0..iterations.unwrap_or(1) {
-        let runtime = ThreadPool::builder().pool_size(1).create().unwrap();
+        let runtime = client.event_loop_group();
         let manager = default_prefetch(runtime, Default::default());
-        let received_size = Arc::new(AtomicU64::new(0));
+        let received_bytes = Arc::new(AtomicU64::new(0));
 
         let start = Instant::now();
 
-        let mut request = manager.prefetch(client.clone(), bucket, key, size, ETag::for_tests());
-        block_on(async {
-            loop {
-                let offset = received_size.load(Ordering::SeqCst);
-                if offset >= size {
-                    break;
-                }
-                let bytes = request.read(offset, 1 << 20).await.unwrap();
-                received_size.fetch_add(bytes.len() as u64, Ordering::SeqCst);
+        thread::scope(|scope| {
+            for _ in 0..downloads {
+                let received_bytes = received_bytes.clone();
+                let mut request = manager.prefetch(client.clone(), bucket, key, size, etag.clone());
+
+                scope.spawn(|| {
+                    futures::executor::block_on(async move {
+                        let mut offset = 0;
+                        while offset < size {
+                            let bytes = request.read(offset, read_size).await.unwrap();
+                            offset += bytes.len() as u64;
+                            received_bytes.fetch_add(bytes.len() as u64, Ordering::SeqCst);
+                        }
+                    })
+                });
             }
         });
 
         let elapsed = start.elapsed();
 
-        let received_size = received_size.load(Ordering::SeqCst);
+        let received_size = received_bytes.load(Ordering::SeqCst);
         println!(
             "{}: received {} bytes in {:.2}s: {:.2}MiB/s",
             i,
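
For context on the concurrency pattern the patch introduces: `std::thread::scope` joins every spawned thread before the scope returns, which is what lets the concurrent downloads share the `received_bytes` counter without `'static` lifetimes, while each thread drives its own `futures::executor::block_on` read loop. Below is a minimal, self-contained sketch of that shape with the prefetcher's `request.read` replaced by a stub; the sizes and the stub are illustrative, not part of the patch.

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;

fn main() {
    let size: u64 = 4 * 1024 * 1024; // pretend object size (illustrative)
    let read_size: u64 = 128 * 1024; // per-request read size, matching the benchmark's default
    let downloads: u64 = 4; // number of concurrent downloads
    let received_bytes = AtomicU64::new(0);

    // All threads spawned inside `thread::scope` are joined before the
    // closure returns, so they can borrow `received_bytes` by reference.
    thread::scope(|scope| {
        for _ in 0..downloads {
            let received_bytes = &received_bytes;
            scope.spawn(move || {
                let mut offset = 0u64;
                while offset < size {
                    // Stand-in for `request.read(offset, read_size).await`:
                    // pretend a full chunk arrived.
                    let n = read_size.min(size - offset);
                    offset += n;
                    received_bytes.fetch_add(n, Ordering::SeqCst);
                }
            });
        }
    });

    // Each of the `downloads` loops reads the whole object once.
    assert_eq!(received_bytes.load(Ordering::SeqCst), size * downloads);
}
```

Compared with the old single `block_on(async { ... })` loop, this shape times the whole group of concurrent readers: the `Instant` measurement in the benchmark brackets the scope, and the scope only returns once every download has finished.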