Refactor Upload to use Tower #50
@@ -0,0 +1,135 @@
use std::sync::Arc;

use crate::{
    error,
    io::{
        part_reader::{Builder as PartReaderBuilder, PartData, ReadPart},
        InputStream,
    },
    operation::upload::UploadContext,
};
use aws_sdk_s3::{primitives::ByteStream, types::CompletedPart};
use bytes::Buf;
use tokio::{sync::Mutex, task};
use tower::{service_fn, Service, ServiceBuilder, ServiceExt};
use tracing::Instrument;

use super::UploadHandle;

/// Request/input type for our "upload_part" service.
pub(super) struct UploadPartRequest {
    pub(super) ctx: UploadContext,
    pub(super) part_data: PartData,
}

/// handler (service fn) for a single part
async fn upload_part_handler(request: UploadPartRequest) -> Result<CompletedPart, error::Error> {
    let ctx = request.ctx;
    let part_data = request.part_data;
    let part_number = part_data.part_number as i32;

    // TODO(aws-sdk-rust#1159): disable payload signing
    // TODO(aws-sdk-rust#1159): set checksum fields if applicable
    let resp = ctx
        .client()
        .upload_part()
        .set_bucket(ctx.request.bucket.clone())
        .set_key(ctx.request.key.clone())
        .set_upload_id(ctx.upload_id.clone())
        .part_number(part_number)
        .content_length(part_data.data.remaining() as i64)
        .body(ByteStream::from(part_data.data))
        .set_sse_customer_algorithm(ctx.request.sse_customer_algorithm.clone())
        .set_sse_customer_key(ctx.request.sse_customer_key.clone())
        .set_sse_customer_key_md5(ctx.request.sse_customer_key_md5.clone())
        .set_request_payer(ctx.request.request_payer.clone())
        .set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone())
        .send()
        .await?;

    tracing::trace!("completed upload of part number {}", part_number);
    let completed = CompletedPart::builder()
        .part_number(part_number)
        .set_e_tag(resp.e_tag.clone())
        .set_checksum_crc32(resp.checksum_crc32.clone())
        .set_checksum_crc32_c(resp.checksum_crc32_c.clone())
        .set_checksum_sha1(resp.checksum_sha1.clone())
        .set_checksum_sha256(resp.checksum_sha256.clone())
        .build();

    Ok(completed)
}

/// Create a new tower::Service for uploading individual parts of an object to S3
pub(super) fn upload_part_service(
    ctx: &UploadContext,
) -> impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
       + Clone
       + Send {
    let svc = service_fn(upload_part_handler);
    ServiceBuilder::new()
        // FIXME - This setting will need to be globalized.
        .concurrency_limit(ctx.handle.num_workers())
Review comment: maybe add a TODO that this needs "globalized"

Reply: Thanks, I have added a FIXME. (A hypothetical sketch of one way the limit could be globalized follows this function.)
        .service(svc)
}
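For illustration only, here is a minimal standalone sketch of what a "globalized" limit could look like: every upload shares one semaphore owned by the transfer manager, so the number of in-flight part uploads is bounded globally rather than per upload. `SharedPartLimit` and all names below are hypothetical, not this crate's API; the sketch assumes tokio with the rt, macros, and sync features. If the tower version in use provides it, `tower::limit::GlobalConcurrencyLimitLayer`, which wraps a shared semaphore, could serve the same purpose directly in the `ServiceBuilder` chain.

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

/// Hypothetical global limit shared by every upload (not this crate's API):
/// all clones hand out permits from the same semaphore, so total in-flight
/// part uploads stay bounded no matter how many uploads are running.
#[derive(Clone)]
struct SharedPartLimit {
    permits: Arc<Semaphore>,
}

impl SharedPartLimit {
    fn new(max_in_flight: usize) -> Self {
        Self {
            permits: Arc::new(Semaphore::new(max_in_flight)),
        }
    }

    /// Run `work` once a global permit is free; the permit is released on drop.
    async fn run<F, T>(&self, work: F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        let _permit = self.permits.acquire().await.expect("semaphore never closed");
        work.await
    }
}

#[tokio::main]
async fn main() {
    let limit = SharedPartLimit::new(2);
    let mut handles = Vec::new();
    for part in 0..8u32 {
        let limit = limit.clone();
        // Pretend each task is an upload_part call from some upload; at most
        // two of them make progress at a time, globally.
        handles.push(tokio::spawn(async move { limit.run(async move { part }).await }));
    }
    for h in handles {
        h.await.expect("upload task panicked");
    }
}
```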

/// Spawn tasks to read the body and upload the remaining parts of object
///
/// # Arguments
///
/// * handle - the handle for this upload
/// * stream - the body input stream
/// * part_size - the part_size for each part
pub(super) fn distribute_work(
    handle: &mut UploadHandle,
    stream: InputStream,
    part_size: u64,
) -> Result<(), error::Error> {
    let part_reader = Arc::new(
        PartReaderBuilder::new()
            .stream(stream)
            .part_size(part_size.try_into().expect("valid part size"))
            .build(),
    );
    let svc = upload_part_service(&handle.ctx);
    let n_workers = handle.ctx.handle.num_workers();
    for i in 0..n_workers {
        let worker = read_body(
            part_reader.clone(),
            handle.ctx.clone(),
            svc.clone(),
            handle.upload_tasks.clone(),
        )
        .instrument(tracing::debug_span!("read_body", worker = i));
        handle.read_tasks.spawn(worker);
    }
Review comment (on lines +96 to +105): Just to better understand, from the PR description: […] Do these sentences imply that […]

Reply: Upload part is restricted by a pool of workers, but instead of us explicitly managing a pool of workers where each worker reads a part_body and then uploads a part, we let tower manage it using the concurrency_limit layer at https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/50/files#diff-a6c023261dc31765237ede4502d30ba640bca1ef9be58cb92e48ccdf69c4768cR72, and the pool of read_body workers simply spawns N upload_part tasks. (A standalone sketch of this pattern follows `distribute_work` below.)

    tracing::trace!("work distributed for uploading parts");
    Ok(())
}
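To make the reply above concrete, here is a minimal standalone sketch (not this crate's code) of the pattern: a tower `concurrency_limit` layer bounds how many calls are in flight across all clones of the service, so the spawning side can submit as many tasks as it likes. The toy request type and the sleep are placeholders for a real part upload; the sketch assumes tokio with full features and tower with the `limit` and `util` features.

```rust
use std::time::Duration;
use tower::{service_fn, ServiceBuilder, ServiceExt};

#[tokio::main]
async fn main() {
    // A toy "upload a part" service: the request is just a part number here.
    let svc = ServiceBuilder::new()
        .concurrency_limit(2) // at most 2 calls in flight across all clones
        .service(service_fn(|part: u32| async move {
            tokio::time::sleep(Duration::from_millis(50)).await; // stand-in for the S3 call
            Ok::<u32, std::convert::Infallible>(part)
        }));

    let mut tasks = tokio::task::JoinSet::new();
    for part in 0..8u32 {
        // `oneshot` waits for the service to report readiness (i.e. a free
        // concurrency permit) before calling it, so only two of these eight
        // spawned tasks are actually "uploading" at any one time.
        tasks.spawn(svc.clone().oneshot(part));
    }
    while let Some(joined) = tasks.join_next().await {
        let part = joined.expect("task panicked").expect("service is infallible");
        println!("completed part {part}");
    }
}
```

All clones of the service share the same concurrency permits, which is what lets the read_body workers spawn one task per part without managing an explicit upload worker pool.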

/// Worker function that pulls part data from the `part_reader` and spawns tasks to upload each part until the reader
/// is exhausted. If any part fails, the worker will return the error and stop processing.
pub(super) async fn read_body(
    part_reader: Arc<impl ReadPart>,
    ctx: UploadContext,
    svc: impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
        + Clone
        + Send
        + 'static,
    upload_tasks: Arc<Mutex<task::JoinSet<Result<CompletedPart, crate::error::Error>>>>,
) -> Result<(), error::Error> {
    while let Some(part_data) = part_reader.next_part().await? {
        let part_number = part_data.part_number;
        let req = UploadPartRequest {
            ctx: ctx.clone(),
            part_data,
        };
        let svc = svc.clone();
        let task = svc.oneshot(req).instrument(tracing::trace_span!(
            "upload_part",
            part_number = part_number
        ));
        upload_tasks.lock().await.spawn(task);
    }
    Ok(())
}
Review comment: This can probably be a regular mutex from stdlib. See https://doc.servo.org/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use

Reply: I am not sure if we can make it a regular mutex. We do keep this lock across await points at https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/50/files#diff-a98b4945e17362a1dcad7da7e15d7ef7af38ff5f88ae751261823a5f23bb3652R135. (See the sketch below for why holding the guard across an await rules out the std mutex.)

Review comment: Ahh, I missed that.
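To expand on that thread, here is a minimal standalone sketch (hypothetical names, not this crate's code) of why the async mutex is needed: the guard is held across a `join_next().await` point, and a `std::sync::MutexGuard` held across an await makes the future `!Send`, which `tokio::spawn` rejects at compile time, while `tokio::sync::Mutex` is designed to be held across awaits. Assumes tokio with full features.

```rust
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

// Drain a shared JoinSet while holding the lock across `join_next().await`.
// With std::sync::Mutex the guard would be held across an await point, the
// future would no longer be Send, and the tokio::spawn call below would fail
// to compile; tokio's Mutex guard is made for exactly this situation.
async fn drain(tasks: Arc<Mutex<JoinSet<u32>>>) -> u32 {
    let mut sum = 0;
    let mut guard = tasks.lock().await;
    while let Some(res) = guard.join_next().await {
        sum += res.expect("task panicked");
    }
    sum
}

#[tokio::main]
async fn main() {
    let tasks = Arc::new(Mutex::new(JoinSet::new()));
    for i in 0..4u32 {
        tasks.lock().await.spawn(async move { i });
    }
    // Works because `drain` uses tokio's Mutex; swap in std::sync::Mutex and
    // this spawn stops compiling because the drain future is no longer Send.
    let sum = tokio::spawn(drain(tasks)).await.expect("drain task panicked");
    assert_eq!(sum, 6); // 0 + 1 + 2 + 3
}
```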