Add support for concurrent downloads to prefetch_benchmark example (#1022)

* Fix prefetch_benchmark example

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

* Add support for concurrent downloads to prefetch_benchmark example

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

* Use CRT runtime

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>

---------

Signed-off-by: Alessandro Passaro <alexpax@amazon.co.uk>
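
For illustration, a possible invocation after this change (the cargo incantation and the bucket/key names are placeholders; the flags map to the clap definitions in the diff below, and the old `size` positional argument is gone because the object size is now read via HeadObject):

    cargo run --release --example prefetch_benchmark -- \
        my-test-bucket my-large-object \
        --downloads 4 --iterations 5 --read-size 1048576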
passaro authored Sep 24, 2024
1 parent ed4735d commit f92bf6c
Showing 1 changed file with 45 additions and 21 deletions.
66 changes: 45 additions & 21 deletions mountpoint-s3/examples/prefetch_benchmark.rs
@@ -1,13 +1,15 @@
+use std::str::FromStr;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
+use std::thread;
 use std::time::Instant;
 
 use clap::{Arg, Command};
-use futures::executor::{block_on, ThreadPool};
+use futures::executor::block_on;
 use mountpoint_s3::prefetch::{default_prefetch, Prefetch, PrefetchResult};
 use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig};
 use mountpoint_s3_client::types::ETag;
-use mountpoint_s3_client::S3CrtClient;
+use mountpoint_s3_client::{ObjectClient, S3CrtClient};
 use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter;
 use tracing_subscriber::fmt::Subscriber;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -32,7 +34,6 @@ fn main() {
         .about("Download a single key from S3 and ignore its contents")
         .arg(Arg::new("bucket").required(true))
         .arg(Arg::new("key").required(true))
-        .arg(Arg::new("size").required(true))
         .arg(
             Arg::new("throughput-target-gbps")
                 .long("throughput-target-gbps")
@@ -43,33 +44,46 @@
                 .long("part-size")
                 .help("Part size for multi-part GET and PUT"),
         )
+        .arg(Arg::new("read-size").long("read-size").help("Size of read requests"))
         .arg(
             Arg::new("iterations")
                 .long("iterations")
                 .help("Number of times to download"),
         )
+        .arg(
+            Arg::new("downloads")
+                .long("downloads")
+                .help("Number of concurrent downloads"),
+        )
         .arg(Arg::new("region").long("region").default_value("us-east-1"))
         .get_matches();
 
     let bucket = matches.get_one::<String>("bucket").unwrap();
     let key = matches.get_one::<String>("key").unwrap();
-    let size = matches
-        .get_one::<String>("size")
-        .unwrap()
-        .parse::<u64>()
-        .expect("size must be u64");
     let throughput_target_gbps = matches
         .get_one::<String>("throughput-target-gbps")
         .map(|s| s.parse::<f64>().expect("throughput target must be an f64"));
     let part_size = matches
         .get_one::<String>("part-size")
         .map(|s| s.parse::<usize>().expect("part size must be a usize"));
+    let read_size = matches
+        .get_one::<String>("read-size")
+        .map(|s| s.parse::<usize>().expect("read size must be a usize"))
+        .unwrap_or(128 * 1024);
     let iterations = matches
         .get_one::<String>("iterations")
         .map(|s| s.parse::<usize>().expect("iterations must be a number"));
+    let downloads = matches
+        .get_one::<String>("downloads")
+        .map(|s| s.parse::<usize>().expect("downloads must be a number"))
+        .unwrap_or(1);
     let region = matches.get_one::<String>("region").unwrap();
 
-    let mut config = S3ClientConfig::new().endpoint_config(EndpointConfig::new(region));
+    let initial_read_window_size = 1024 * 1024 + 128 * 1024;
+    let mut config = S3ClientConfig::new()
+        .endpoint_config(EndpointConfig::new(region))
+        .read_backpressure(true)
+        .initial_read_window(initial_read_window_size);
     if let Some(throughput_target_gbps) = throughput_target_gbps {
         config = config.throughput_target_gbps(throughput_target_gbps);
     }
@@ -78,28 +92,38 @@ fn main() {
     }
     let client = Arc::new(S3CrtClient::new(config).expect("couldn't create client"));
 
+    let head_object_result = block_on(client.head_object(bucket, key)).expect("HeadObject failed");
+    let size = head_object_result.object.size;
+    let etag = ETag::from_str(&head_object_result.object.etag).unwrap();
+
     for i in 0..iterations.unwrap_or(1) {
-        let runtime = ThreadPool::builder().pool_size(1).create().unwrap();
+        let runtime = client.event_loop_group();
         let manager = default_prefetch(runtime, Default::default());
-        let received_size = Arc::new(AtomicU64::new(0));
+        let received_bytes = Arc::new(AtomicU64::new(0));
 
         let start = Instant::now();
 
-        let mut request = manager.prefetch(client.clone(), bucket, key, size, ETag::for_tests());
-        block_on(async {
-            loop {
-                let offset = received_size.load(Ordering::SeqCst);
-                if offset >= size {
-                    break;
-                }
-                let bytes = request.read(offset, 1 << 20).await.unwrap();
-                received_size.fetch_add(bytes.len() as u64, Ordering::SeqCst);
+        thread::scope(|scope| {
+            for _ in 0..downloads {
+                let received_bytes = received_bytes.clone();
+                let mut request = manager.prefetch(client.clone(), bucket, key, size, etag.clone());
+
+                scope.spawn(|| {
+                    futures::executor::block_on(async move {
+                        let mut offset = 0;
+                        while offset < size {
+                            let bytes = request.read(offset, read_size).await.unwrap();
+                            offset += bytes.len() as u64;
+                            received_bytes.fetch_add(bytes.len() as u64, Ordering::SeqCst);
+                        }
+                    })
+                });
             }
         });
 
         let elapsed = start.elapsed();
 
-        let received_size = received_size.load(Ordering::SeqCst);
+        let received_size = received_bytes.load(Ordering::SeqCst);
         println!(
             "{}: received {} bytes in {:.2}s: {:.2}MiB/s",
             i,
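Below is a minimal, std-only Rust sketch of the fan-out pattern the change introduces. All values are made up, and the chunk arithmetic merely stands in for the prefetcher's request.read(offset, read_size) calls; the point is how thread::scope joins every spawned worker before returning, which lets workers borrow non-'static locals (here the shared counter; in the example, also the client, bucket, and key):

    use std::sync::atomic::{AtomicU64, Ordering};
    use std::thread;

    fn main() {
        let downloads: u64 = 4; // number of concurrent "downloads" (made up)
        let size: u64 = 1 << 30; // pretend object size: 1 GiB
        let read_size: u64 = 128 * 1024; // the example's default read size
        let received_bytes = AtomicU64::new(0);

        // All threads spawned on `scope` are joined before thread::scope
        // returns, so they may borrow local state without 'static bounds.
        thread::scope(|scope| {
            for _ in 0..downloads {
                let received_bytes = &received_bytes;
                scope.spawn(move || {
                    let mut offset = 0;
                    while offset < size {
                        // Stand-in for: request.read(offset, read_size).await
                        let chunk = read_size.min(size - offset);
                        offset += chunk;
                        received_bytes.fetch_add(chunk, Ordering::SeqCst);
                    }
                });
            }
        });

        // Each worker reads the whole object once.
        assert_eq!(received_bytes.load(Ordering::SeqCst), downloads * size);
    }

Note that with scoped threads a plain &AtomicU64 suffices, as above; the example clones an Arc<AtomicU64> instead, since each worker also needs its own handle to the Arc'd client.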
