Refactor Upload to use Tower #50

Merged · 24 commits · Sep 24, 2024
Changes from all commits
81 changes: 4 additions & 77 deletions aws-s3-transfer-manager/src/operation/upload.rs
@@ -10,22 +10,20 @@ mod output;

mod context;
mod handle;
mod service;

use crate::error;
use crate::io::part_reader::{Builder as PartReaderBuilder, ReadPart};
use crate::io::InputStream;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::CompletedPart;
use bytes::Buf;
use context::UploadContext;
pub use handle::UploadHandle;
/// Request type for uploads to Amazon S3
pub use input::{UploadInput, UploadInputBuilder};
/// Response type for uploads to Amazon S3
pub use output::{UploadOutput, UploadOutputBuilder};
use service::distribute_work;

use std::cmp;
use std::sync::Arc;
use tracing::Instrument;

/// Maximum number of parts that a single S3 multipart upload supports
const MAX_PARTS: u64 = 10_000;
@@ -89,21 +87,7 @@ async fn try_start_mpu_upload(
);

handle.set_response(mpu);

let part_reader = Arc::new(
PartReaderBuilder::new()
.stream(stream)
.part_size(part_size.try_into().expect("valid part size"))
.build(),
);

let n_workers = handle.ctx.handle.num_workers();
for i in 0..n_workers {
let worker = upload_parts(handle.ctx.clone(), part_reader.clone())
.instrument(tracing::debug_span!("upload-part", worker = i));
handle.tasks.spawn(worker);
}

distribute_work(handle, stream, part_size)?;
Ok(())
}

@@ -158,63 +142,6 @@ async fn start_mpu(handle: &UploadHandle) -> Result<UploadOutputBuilder, crate::
Ok(resp.into())
}

/// Worker function that pulls part data off the `reader` and uploads each part until the reader
/// is exhausted. If any part fails the worker will return the error and stop processing. If
/// the worker finishes successfully the completed parts uploaded by this worker are returned.
async fn upload_parts(
ctx: UploadContext,
reader: Arc<impl ReadPart>,
) -> Result<Vec<CompletedPart>, error::Error> {
let mut completed_parts = Vec::new();
loop {
let part_result = reader.next_part().await?;
let part_data = match part_result {
Some(part_data) => part_data,
None => break,
};

let part_number = part_data.part_number as i32;
tracing::trace!("recv'd part number {}", part_number);

let content_length = part_data.data.remaining();
let body = ByteStream::from(part_data.data);

// TODO(aws-sdk-rust#1159): disable payload signing
// TODO(aws-sdk-rust#1159): set checksum fields if applicable
let resp = ctx
.client()
.upload_part()
.set_bucket(ctx.request.bucket.clone())
.set_key(ctx.request.key.clone())
.set_upload_id(ctx.upload_id.clone())
.part_number(part_number)
.content_length(content_length as i64)
.body(body)
.set_sse_customer_algorithm(ctx.request.sse_customer_algorithm.clone())
.set_sse_customer_key(ctx.request.sse_customer_key.clone())
.set_sse_customer_key_md5(ctx.request.sse_customer_key_md5.clone())
.set_request_payer(ctx.request.request_payer.clone())
.set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone())
.send()
.await?;

tracing::trace!("completed upload of part number {}", part_number);
let completed = CompletedPart::builder()
.part_number(part_number)
.set_e_tag(resp.e_tag.clone())
.set_checksum_crc32(resp.checksum_crc32.clone())
.set_checksum_crc32_c(resp.checksum_crc32_c.clone())
.set_checksum_sha1(resp.checksum_sha1.clone())
.set_checksum_sha256(resp.checksum_sha256.clone())
.build();

completed_parts.push(completed);
}

tracing::trace!("no more parts, worker finished");
Ok(completed_parts)
}

#[cfg(test)]
mod test {
use crate::io::InputStream;
43 changes: 32 additions & 11 deletions aws-s3-transfer-manager/src/operation/upload/handle.rs
@@ -3,19 +3,24 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::Arc;

use crate::operation::upload::context::UploadContext;
use crate::operation::upload::{UploadOutput, UploadOutputBuilder};
use crate::types::{AbortedUpload, FailedMultipartUploadPolicy};
use aws_sdk_s3::error::DisplayErrorContext;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use tokio::sync::Mutex;
use tokio::task;

/// Response type for a single upload object request.
#[derive(Debug)]
#[non_exhaustive]
pub struct UploadHandle {
/// All child multipart upload tasks spawned for this upload
pub(crate) tasks: task::JoinSet<Result<Vec<CompletedPart>, crate::error::Error>>,
pub(crate) upload_tasks: Arc<Mutex<task::JoinSet<Result<CompletedPart, crate::error::Error>>>>,

Contributor Author: I am not sure if we can make it a regular mutex. We do keep this lock across await points at https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/50/files#diff-a98b4945e17362a1dcad7da7e15d7ef7af38ff5f88ae751261823a5f23bb3652R135.

Contributor: Ahh I missed that.

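For context, a minimal sketch (not from this PR, names made up) of why tokio::sync::Mutex is used here: its guard can be held across an .await, whereas a std::sync::MutexGuard is not Send, so holding it across an await point would make the future unspawnable on a multi-threaded runtime. complete_upload relies on exactly this when it holds the lock while joining tasks.

use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

// Hypothetical helper mirroring the pattern: the lock guard stays alive while
// we await join_next(), which requires tokio's async mutex rather than std's.
async fn drain(tasks: Arc<Mutex<JoinSet<u64>>>) -> Vec<u64> {
    let mut out = Vec::new();
    let mut guard = tasks.lock().await; // async lock acquisition
    while let Some(joined) = guard.join_next().await {
        // `guard` is held across this await point
        out.push(joined.expect("task completed"));
    }
    out
}
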
/// All child read body tasks spawned for this upload
pub(crate) read_tasks: task::JoinSet<Result<(), crate::error::Error>>,
/// The context used to drive an upload to completion
pub(crate) ctx: UploadContext,
/// The response that will eventually be yielded to the caller.
@@ -26,7 +31,8 @@ impl UploadHandle {
/// Create a new upload handle with the given request context
pub(crate) fn new(ctx: UploadContext) -> Self {
Self {
tasks: task::JoinSet::new(),
upload_tasks: Arc::new(Mutex::new(task::JoinSet::new())),
read_tasks: task::JoinSet::new(),
ctx,
response: None,
}
@@ -54,11 +60,16 @@ impl UploadHandle {
pub async fn abort(&mut self) -> Result<AbortedUpload, crate::error::Error> {
// TODO(aws-sdk-rust#1159) - handle already completed upload

// cancel in-progress uploads
self.tasks.abort_all();
// cancel in-progress read_body tasks
self.read_tasks.abort_all();
while (self.read_tasks.join_next().await).is_some() {}

// cancel in-progress upload tasks
let mut tasks = self.upload_tasks.lock().await;
tasks.abort_all();

// join all tasks
while (self.tasks.join_next().await).is_some() {}
while (tasks.join_next().await).is_some() {}

if !self.ctx.is_multipart_upload() {
return Ok(AbortedUpload::default());
@@ -107,19 +118,29 @@ async fn complete_upload(mut handle: UploadHandle) -> Result<UploadOutput, crate
let span = tracing::debug_span!("joining upload", upload_id = handle.ctx.upload_id);
let _enter = span.enter();

let mut all_parts = Vec::new();
while let Some(join_result) = handle.read_tasks.join_next().await {
if let Err(err) = join_result.expect("task completed") {
tracing::error!("multipart upload failed while trying to read the body, aborting");
// TODO(aws-sdk-rust#1159) - if cancelling causes an error we want to propagate that in the returned error somehow?
if let Err(err) = handle.abort().await {
tracing::error!("failed to abort upload: {}", DisplayErrorContext(err))
};
return Err(err);
}
}

// join all the upload tasks
while let Some(join_result) = handle.tasks.join_next().await {
let mut all_parts = Vec::new();
// join all the upload tasks. We can safely grab the lock since all the read_tasks are done.
let mut tasks = handle.upload_tasks.lock().await;
while let Some(join_result) = tasks.join_next().await {
let result = join_result.expect("task completed");
match result {
Ok(mut completed_parts) => {
all_parts.append(&mut completed_parts);
}
Ok(completed_part) => all_parts.push(completed_part),
// TODO(aws-sdk-rust#1159, design) - do we want to return first error or collect all errors?
Err(err) => {
tracing::error!("multipart upload failed, aborting");
// TODO(aws-sdk-rust#1159) - if cancelling causes an error we want to propagate that in the returned error somehow?
drop(tasks);
if let Err(err) = handle.abort().await {
tracing::error!("failed to abort upload: {}", DisplayErrorContext(err))
};
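
As an aside, a minimal sketch (hypothetical names, not part of the diff) of the cancel-then-drain pattern that abort() uses: JoinSet::abort_all only requests cancellation, and draining with join_next afterwards is what guarantees every task has actually finished before abort returns.

use tokio::task::JoinSet;

// Hypothetical helper: abort_all() signals cancellation; the drain loop waits
// until every task has observed it (cancelled tasks surface as JoinErrors).
async fn cancel_and_drain(tasks: &mut JoinSet<()>) {
    tasks.abort_all();
    while let Some(res) = tasks.join_next().await {
        if let Err(err) = res {
            debug_assert!(err.is_cancelled() || err.is_panic());
        }
    }
}
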
135 changes: 135 additions & 0 deletions aws-s3-transfer-manager/src/operation/upload/service.rs
@@ -0,0 +1,135 @@
use std::sync::Arc;

use crate::{
error,
io::{
part_reader::{Builder as PartReaderBuilder, PartData, ReadPart},
InputStream,
},
operation::upload::UploadContext,
};
use aws_sdk_s3::{primitives::ByteStream, types::CompletedPart};
use bytes::Buf;
use tokio::{sync::Mutex, task};
use tower::{service_fn, Service, ServiceBuilder, ServiceExt};
use tracing::Instrument;

use super::UploadHandle;

/// Request/input type for our "upload_part" service.
pub(super) struct UploadPartRequest {
pub(super) ctx: UploadContext,
pub(super) part_data: PartData,
}

/// handler (service fn) for a single part
async fn upload_part_handler(request: UploadPartRequest) -> Result<CompletedPart, error::Error> {
let ctx = request.ctx;
let part_data = request.part_data;
let part_number = part_data.part_number as i32;

// TODO(aws-sdk-rust#1159): disable payload signing
// TODO(aws-sdk-rust#1159): set checksum fields if applicable
let resp = ctx
.client()
.upload_part()
.set_bucket(ctx.request.bucket.clone())
.set_key(ctx.request.key.clone())
.set_upload_id(ctx.upload_id.clone())
.part_number(part_number)
.content_length(part_data.data.remaining() as i64)
.body(ByteStream::from(part_data.data))
.set_sse_customer_algorithm(ctx.request.sse_customer_algorithm.clone())
.set_sse_customer_key(ctx.request.sse_customer_key.clone())
.set_sse_customer_key_md5(ctx.request.sse_customer_key_md5.clone())
.set_request_payer(ctx.request.request_payer.clone())
.set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone())
.send()
.await?;

tracing::trace!("completed upload of part number {}", part_number);
let completed = CompletedPart::builder()
.part_number(part_number)
.set_e_tag(resp.e_tag.clone())
.set_checksum_crc32(resp.checksum_crc32.clone())
.set_checksum_crc32_c(resp.checksum_crc32_c.clone())
.set_checksum_sha1(resp.checksum_sha1.clone())
.set_checksum_sha256(resp.checksum_sha256.clone())
.build();

Ok(completed)
}

/// Create a new tower::Service for uploading individual parts of an object to S3
pub(super) fn upload_part_service(
ctx: &UploadContext,
) -> impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
+ Clone
+ Send {
let svc = service_fn(upload_part_handler);
ServiceBuilder::new()
// FIXME - This setting will need to be globalized.
.concurrency_limit(ctx.handle.num_workers())

Contributor: maybe add a TODO that this needs "globalized"

Contributor Author: Thanks, I have added a FIXME.

.service(svc)
}
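
A minimal, self-contained sketch of the same tower pattern (demo names, nothing here comes from the crate): wrap an async handler in service_fn, cap in-flight calls with concurrency_limit, and drive a single request through oneshot, which waits for a free slot before calling the service.

use tower::{service_fn, ServiceBuilder, ServiceExt};

// Hypothetical stand-in for upload_part_handler.
async fn handle(part: u32) -> Result<u32, std::convert::Infallible> {
    Ok(part * 2)
}

#[tokio::main]
async fn main() {
    // At most 8 calls in flight at once; extra callers wait for capacity.
    let svc = ServiceBuilder::new()
        .concurrency_limit(8)
        .service(service_fn(handle));

    // oneshot() polls the service to readiness and then calls it once.
    let doubled = svc.oneshot(21).await.unwrap();
    assert_eq!(doubled, 42);
}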

/// Spawn tasks to read the body and upload the remaining parts of object
///
/// # Arguments
///
/// * handle - the handle for this upload
/// * stream - the body input stream
/// * part_size - the part_size for each part
pub(super) fn distribute_work(
handle: &mut UploadHandle,
stream: InputStream,
part_size: u64,
) -> Result<(), error::Error> {
let part_reader = Arc::new(
PartReaderBuilder::new()
.stream(stream)
.part_size(part_size.try_into().expect("valid part size"))
.build(),
);
let svc = upload_part_service(&handle.ctx);
let n_workers = handle.ctx.handle.num_workers();
for i in 0..n_workers {
let worker = read_body(
part_reader.clone(),
handle.ctx.clone(),
svc.clone(),
handle.upload_tasks.clone(),
)
.instrument(tracing::debug_span!("read_body", worker = i));
handle.read_tasks.spawn(worker);
}

Comment on lines +96 to +105

Just to better understand, from the PR description:

This also refactors the upload pipeline to distinguish between the read_body and upload_part phases. read_body still uses a fixed pool of workers because we need to support unknown content length, and I couldn't figure out how to implement an unknown amount of work in tower.

Do these sentences imply that upload_part is NOT restricted by the fixed pool of workers, since it only refers to read_body still using a fixed pool of workers?

Contributor Author: Upload part is restricted by a pool of workers, but instead of us explicitly managing a pool of workers where each worker reads a part_body and then uploads a part, we let tower manage it using the concurrency_limit layer at https://github.com/awslabs/aws-s3-transfer-manager-rs/pull/50/files#diff-a6c023261dc31765237ede4502d30ba640bca1ef9be58cb92e48ccdf69c4768cR72, and the pool of read_body workers simply spawns N upload_part tasks.

tracing::trace!("work distributed for uploading parts");
Ok(())
}

/// Worker function that pulls part data from the `part_reader` and spawns tasks to upload each part until the reader
/// is exhausted. If any part fails, the worker will return the error and stop processing.
pub(super) async fn read_body(
part_reader: Arc<impl ReadPart>,
ctx: UploadContext,
svc: impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
+ Clone
+ Send
+ 'static,
upload_tasks: Arc<Mutex<task::JoinSet<Result<CompletedPart, crate::error::Error>>>>,
) -> Result<(), error::Error> {
while let Some(part_data) = part_reader.next_part().await? {
let part_number = part_data.part_number;
let req = UploadPartRequest {
ctx: ctx.clone(),
part_data,
};
let svc = svc.clone();
let task = svc.oneshot(req).instrument(tracing::trace_span!(
"upload_part",
part_number = part_number
));
upload_tasks.lock().await.spawn(task);
}
Ok(())
}
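
Putting the pieces together, a simplified end-to-end sketch of the design discussed above (all names are made up for illustration): a fixed pool of reader workers pulls parts from a shared source and spawns one upload task per part into a shared JoinSet, while tower's concurrency_limit bounds how many uploads run at once.

use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tower::{service_fn, ServiceBuilder, ServiceExt};

#[tokio::main]
async fn main() {
    // Stand-in for the part reader: a shared queue of fake part numbers.
    let parts = Arc::new(Mutex::new((1..=10).collect::<Vec<i32>>()));
    // Shared JoinSet collecting one task per uploaded part.
    let uploads: Arc<Mutex<JoinSet<i32>>> = Arc::new(Mutex::new(JoinSet::new()));

    // Bounded "upload" service: at most 4 uploads in flight across all readers.
    let svc = ServiceBuilder::new()
        .concurrency_limit(4)
        .service(service_fn(|part: i32| async move {
            Ok::<_, std::convert::Infallible>(part)
        }));

    // Fixed pool of reader workers, mirroring read_body.
    let mut readers = JoinSet::new();
    for _ in 0..2 {
        let (parts, uploads, svc) = (parts.clone(), uploads.clone(), svc.clone());
        readers.spawn(async move {
            loop {
                let next = parts.lock().await.pop();
                let Some(part) = next else { break };
                let svc = svc.clone();
                uploads
                    .lock()
                    .await
                    .spawn(async move { svc.oneshot(part).await.unwrap() });
            }
        });
    }
    while readers.join_next().await.is_some() {}

    // All readers are done, so nothing else spawns into the set; drain it.
    let mut uploads = uploads.lock().await;
    let mut completed = Vec::new();
    while let Some(part) = uploads.join_next().await {
        completed.push(part.expect("upload task completed"));
    }
    assert_eq!(completed.len(), 10);
}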