Skip to content

Commit

Permalink
Refactor Upload to use Tower (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
waahm7 authored Sep 24, 2024
1 parent 02887a2 commit e929fac
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 88 deletions.
81 changes: 4 additions & 77 deletions aws-s3-transfer-manager/src/operation/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,20 @@ mod output;

mod context;
mod handle;
mod service;

use crate::error;
use crate::io::part_reader::{Builder as PartReaderBuilder, ReadPart};
use crate::io::InputStream;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::CompletedPart;
use bytes::Buf;
use context::UploadContext;
pub use handle::UploadHandle;
/// Request type for uploads to Amazon S3
pub use input::{UploadInput, UploadInputBuilder};
/// Response type for uploads to Amazon S3
pub use output::{UploadOutput, UploadOutputBuilder};
use service::distribute_work;

use std::cmp;
use std::sync::Arc;
use tracing::Instrument;

/// Maximum number of parts that a single S3 multipart upload supports
const MAX_PARTS: u64 = 10_000;
Expand Down Expand Up @@ -89,21 +87,7 @@ async fn try_start_mpu_upload(
);

handle.set_response(mpu);

let part_reader = Arc::new(
PartReaderBuilder::new()
.stream(stream)
.part_size(part_size.try_into().expect("valid part size"))
.build(),
);

let n_workers = handle.ctx.handle.num_workers();
for i in 0..n_workers {
let worker = upload_parts(handle.ctx.clone(), part_reader.clone())
.instrument(tracing::debug_span!("upload-part", worker = i));
handle.tasks.spawn(worker);
}

distribute_work(handle, stream, part_size)?;
Ok(())
}

Expand Down Expand Up @@ -158,63 +142,6 @@ async fn start_mpu(handle: &UploadHandle) -> Result<UploadOutputBuilder, crate::
Ok(resp.into())
}

/// Worker task that drains the shared `reader`, uploading one part per iteration until the
/// stream is exhausted. The first failed part aborts the worker and surfaces the error;
/// otherwise the `CompletedPart`s produced by this worker are returned in upload order.
async fn upload_parts(
    ctx: UploadContext,
    reader: Arc<impl ReadPart>,
) -> Result<Vec<CompletedPart>, error::Error> {
    let mut completed = Vec::new();

    while let Some(part_data) = reader.next_part().await? {
        let part_number = part_data.part_number as i32;
        tracing::trace!("recv'd part number {}", part_number);

        let content_length = part_data.data.remaining();
        let body = ByteStream::from(part_data.data);

        // TODO(aws-sdk-rust#1159): disable payload signing
        // TODO(aws-sdk-rust#1159): set checksum fields if applicable
        let resp = ctx
            .client()
            .upload_part()
            .set_bucket(ctx.request.bucket.clone())
            .set_key(ctx.request.key.clone())
            .set_upload_id(ctx.upload_id.clone())
            .part_number(part_number)
            .content_length(content_length as i64)
            .body(body)
            .set_sse_customer_algorithm(ctx.request.sse_customer_algorithm.clone())
            .set_sse_customer_key(ctx.request.sse_customer_key.clone())
            .set_sse_customer_key_md5(ctx.request.sse_customer_key_md5.clone())
            .set_request_payer(ctx.request.request_payer.clone())
            .set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone())
            .send()
            .await?;

        tracing::trace!("completed upload of part number {}", part_number);
        let part = CompletedPart::builder()
            .part_number(part_number)
            .set_e_tag(resp.e_tag.clone())
            .set_checksum_crc32(resp.checksum_crc32.clone())
            .set_checksum_crc32_c(resp.checksum_crc32_c.clone())
            .set_checksum_sha1(resp.checksum_sha1.clone())
            .set_checksum_sha256(resp.checksum_sha256.clone())
            .build();
        completed.push(part);
    }

    tracing::trace!("no more parts, worker finished");
    Ok(completed)
}

#[cfg(test)]
mod test {
use crate::io::InputStream;
Expand Down
43 changes: 32 additions & 11 deletions aws-s3-transfer-manager/src/operation/upload/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::Arc;

use crate::operation::upload::context::UploadContext;
use crate::operation::upload::{UploadOutput, UploadOutputBuilder};
use crate::types::{AbortedUpload, FailedMultipartUploadPolicy};
use aws_sdk_s3::error::DisplayErrorContext;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use tokio::sync::Mutex;
use tokio::task;

/// Response type for a single upload object request.
#[derive(Debug)]
#[non_exhaustive]
pub struct UploadHandle {
/// All child multipart upload tasks spawned for this upload
pub(crate) tasks: task::JoinSet<Result<Vec<CompletedPart>, crate::error::Error>>,
pub(crate) upload_tasks: Arc<Mutex<task::JoinSet<Result<CompletedPart, crate::error::Error>>>>,
/// All child read body tasks spawned for this upload
pub(crate) read_tasks: task::JoinSet<Result<(), crate::error::Error>>,
/// The context used to drive an upload to completion
pub(crate) ctx: UploadContext,
/// The response that will eventually be yielded to the caller.
Expand All @@ -26,7 +31,8 @@ impl UploadHandle {
/// Create a new upload handle with the given request context
pub(crate) fn new(ctx: UploadContext) -> Self {
Self {
tasks: task::JoinSet::new(),
upload_tasks: Arc::new(Mutex::new(task::JoinSet::new())),
read_tasks: task::JoinSet::new(),
ctx,
response: None,
}
Expand Down Expand Up @@ -54,11 +60,16 @@ impl UploadHandle {
pub async fn abort(&mut self) -> Result<AbortedUpload, crate::error::Error> {
// TODO(aws-sdk-rust#1159) - handle already completed upload

// cancel in-progress uploads
self.tasks.abort_all();
// cancel in-progress read_body tasks
self.read_tasks.abort_all();
while (self.read_tasks.join_next().await).is_some() {}

// cancel in-progress upload tasks
let mut tasks = self.upload_tasks.lock().await;
tasks.abort_all();

// join all tasks
while (self.tasks.join_next().await).is_some() {}
while (tasks.join_next().await).is_some() {}

if !self.ctx.is_multipart_upload() {
return Ok(AbortedUpload::default());
Expand Down Expand Up @@ -107,19 +118,29 @@ async fn complete_upload(mut handle: UploadHandle) -> Result<UploadOutput, crate
let span = tracing::debug_span!("joining upload", upload_id = handle.ctx.upload_id);
let _enter = span.enter();

let mut all_parts = Vec::new();
while let Some(join_result) = handle.read_tasks.join_next().await {
if let Err(err) = join_result.expect("task completed") {
tracing::error!("multipart upload failed while trying to read the body, aborting");
// TODO(aws-sdk-rust#1159) - if cancelling causes an error we want to propagate that in the returned error somehow?
if let Err(err) = handle.abort().await {
tracing::error!("failed to abort upload: {}", DisplayErrorContext(err))
};
return Err(err);
}
}

// join all the upload tasks
while let Some(join_result) = handle.tasks.join_next().await {
let mut all_parts = Vec::new();
// join all the upload tasks. We can safely grab the lock since all the read_tasks are done.
let mut tasks = handle.upload_tasks.lock().await;
while let Some(join_result) = tasks.join_next().await {
let result = join_result.expect("task completed");
match result {
Ok(mut completed_parts) => {
all_parts.append(&mut completed_parts);
}
Ok(completed_part) => all_parts.push(completed_part),
// TODO(aws-sdk-rust#1159, design) - do we want to return first error or collect all errors?
Err(err) => {
tracing::error!("multipart upload failed, aborting");
// TODO(aws-sdk-rust#1159) - if cancelling causes an error we want to propagate that in the returned error somehow?
drop(tasks);
if let Err(err) = handle.abort().await {
tracing::error!("failed to abort upload: {}", DisplayErrorContext(err))
};
Expand Down
135 changes: 135 additions & 0 deletions aws-s3-transfer-manager/src/operation/upload/service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::sync::Arc;

use crate::{
error,
io::{
part_reader::{Builder as PartReaderBuilder, PartData, ReadPart},
InputStream,
},
operation::upload::UploadContext,
};
use aws_sdk_s3::{primitives::ByteStream, types::CompletedPart};
use bytes::Buf;
use tokio::{sync::Mutex, task};
use tower::{service_fn, Service, ServiceBuilder, ServiceExt};
use tracing::Instrument;

use super::UploadHandle;

/// Request/input type for our "upload_part" service.
pub(super) struct UploadPartRequest {
    /// Shared upload context: S3 client, original request parameters, and upload id.
    pub(super) ctx: UploadContext,
    /// The body bytes and part number for a single part, as produced by the part reader.
    pub(super) part_data: PartData,
}

/// Handler (service fn) for a single part: sends an `UploadPart` request for the
/// part carried by `request` and converts the response into a `CompletedPart`.
///
/// # Errors
/// Returns the SDK error if the `UploadPart` call fails.
async fn upload_part_handler(request: UploadPartRequest) -> Result<CompletedPart, error::Error> {
    let ctx = request.ctx;
    let part_data = request.part_data;
    let part_number = part_data.part_number as i32;

    // TODO(aws-sdk-rust#1159): disable payload signing
    // TODO(aws-sdk-rust#1159): set checksum fields if applicable
    let resp = ctx
        .client()
        .upload_part()
        .set_bucket(ctx.request.bucket.clone())
        .set_key(ctx.request.key.clone())
        .set_upload_id(ctx.upload_id.clone())
        .part_number(part_number)
        .content_length(part_data.data.remaining() as i64)
        .body(ByteStream::from(part_data.data))
        .set_sse_customer_algorithm(ctx.request.sse_customer_algorithm.clone())
        .set_sse_customer_key(ctx.request.sse_customer_key.clone())
        .set_sse_customer_key_md5(ctx.request.sse_customer_key_md5.clone())
        .set_request_payer(ctx.request.request_payer.clone())
        .set_expected_bucket_owner(ctx.request.expected_bucket_owner.clone())
        .send()
        .await?;

    tracing::trace!("completed upload of part number {}", part_number);

    // `resp` is owned and unused afterwards: move the checksum/etag fields out
    // instead of cloning each `Option<String>` (avoids five needless allocations).
    let completed = CompletedPart::builder()
        .part_number(part_number)
        .set_e_tag(resp.e_tag)
        .set_checksum_crc32(resp.checksum_crc32)
        .set_checksum_crc32_c(resp.checksum_crc32_c)
        .set_checksum_sha1(resp.checksum_sha1)
        .set_checksum_sha256(resp.checksum_sha256)
        .build();

    Ok(completed)
}

/// Build the tower `Service` used to upload individual parts of an object to S3.
/// Concurrency across all in-flight part uploads is bounded by the transfer
/// manager's worker count.
pub(super) fn upload_part_service(
    ctx: &UploadContext,
) -> impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
       + Clone
       + Send {
    ServiceBuilder::new()
        // FIXME - This setting will need to be globalized.
        .concurrency_limit(ctx.handle.num_workers())
        .service(service_fn(upload_part_handler))
}

/// Spawn tasks to read the body and upload the remaining parts of the object.
///
/// # Arguments
///
/// * `handle` - the handle for this upload; read tasks are spawned onto it
/// * `stream` - the body input stream
/// * `part_size` - the size of each part to upload
pub(super) fn distribute_work(
    handle: &mut UploadHandle,
    stream: InputStream,
    part_size: u64,
) -> Result<(), error::Error> {
    let reader = PartReaderBuilder::new()
        .stream(stream)
        .part_size(part_size.try_into().expect("valid part size"))
        .build();
    let reader = Arc::new(reader);
    let svc = upload_part_service(&handle.ctx);

    // One read_body worker per configured worker slot; each pulls parts off the
    // shared reader and hands them to the upload service.
    for worker in 0..handle.ctx.handle.num_workers() {
        let task = read_body(
            Arc::clone(&reader),
            handle.ctx.clone(),
            svc.clone(),
            handle.upload_tasks.clone(),
        )
        .instrument(tracing::debug_span!("read_body", worker));
        handle.read_tasks.spawn(task);
    }

    tracing::trace!("work distributed for uploading parts");
    Ok(())
}

/// Worker that drains `part_reader`, spawning one upload task per part onto
/// `upload_tasks`, until the reader is exhausted. If reading a part fails, the
/// worker stops and returns that error.
pub(super) async fn read_body(
    part_reader: Arc<impl ReadPart>,
    ctx: UploadContext,
    svc: impl Service<UploadPartRequest, Response = CompletedPart, Error = error::Error, Future: Send>
        + Clone
        + Send
        + 'static,
    upload_tasks: Arc<Mutex<task::JoinSet<Result<CompletedPart, crate::error::Error>>>>,
) -> Result<(), error::Error> {
    loop {
        let part_data = match part_reader.next_part().await? {
            Some(data) => data,
            None => break,
        };
        let part_number = part_data.part_number;
        let request = UploadPartRequest {
            ctx: ctx.clone(),
            part_data,
        };
        let span = tracing::trace_span!("upload_part", part_number);
        let upload = svc.clone().oneshot(request).instrument(span);
        upload_tasks.lock().await.spawn(upload);
    }
    Ok(())
}

0 comments on commit e929fac

Please sign in to comment.