Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Download Memory Usage #61

Merged
merged 15 commits into from
Oct 8, 2024
27 changes: 20 additions & 7 deletions aws-s3-transfer-manager/src/operation/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use aws_smithy_types::byte_stream::ByteStream;
use body::Body;
use discovery::discover_obj;
use service::{distribute_work, ChunkResponse};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
Expand Down Expand Up @@ -96,9 +97,8 @@ fn handle_discovery_chunk(
completed: &mpsc::Sender<Result<ChunkResponse, crate::error::Error>>,
permit: OwnedWorkPermit,
) -> u64 {
let mut start_seq = 0;

if let Some(stream) = initial_chunk {
let seq = handle.ctx.next_seq();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_seq and seq aren't connected anywhere here; we should probably set start_seq based on seq instead of hard-coding it to 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks — we would have to set start_seq = seq + 1. I have added a current_seq() function which will return 0 or 1 depending upon whether we called next_seq or not.

let completed = completed.clone();
// spawn a task to actually read the discovery chunk without waiting for it so we
// can get started sooner on any remaining work (if any)
Expand All @@ -107,7 +107,7 @@ fn handle_discovery_chunk(
.collect()
.await
.map(|aggregated| ChunkResponse {
seq: start_seq,
seq,
data: Some(aggregated),
})
.map_err(error::discovery_failed);
Expand All @@ -122,25 +122,38 @@ fn handle_discovery_chunk(
);
}
});
start_seq = 1;
}
start_seq
handle.ctx.current_seq()
}

/// Download operation specific state
#[derive(Debug)]
pub(crate) struct DownloadState {}
pub(crate) struct DownloadState {
current_seq: AtomicU64,
}

type DownloadContext = TransferContext<DownloadState>;

impl DownloadContext {
fn new(handle: Arc<crate::client::Handle>) -> Self {
let state = Arc::new(DownloadState {});
let state = Arc::new(DownloadState {
current_seq: AtomicU64::new(0),
});
TransferContext { handle, state }
}

/// The target part (chunk) size, in bytes, to use for this download.
///
/// Delegates to the client handle's configured download part size.
fn target_part_size_bytes(&self) -> u64 {
    self.handle.download_part_size_bytes()
}

/// Returns the next sequence number to assign to a chunk download.
///
/// Atomically increments the shared counter and returns its value *before*
/// the increment (`fetch_add` semantics), so concurrent callers each receive
/// a unique, monotonically increasing seq starting from 0.
fn next_seq(&self) -> u64 {
    self.state.current_seq.fetch_add(1, Ordering::SeqCst)
}

/// Returns the current value of the sequence counter.
///
/// This is the number of sequence numbers handed out so far — equivalently,
/// the value the next call to `next_seq` would return. It does not modify
/// the counter.
fn current_seq(&self) -> u64 {
    self.state.current_seq.load(Ordering::SeqCst)
}
}
94 changes: 34 additions & 60 deletions aws-s3-transfer-manager/src/operation/download/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,39 @@ use super::{DownloadHandle, DownloadInput, DownloadInputBuilder};
#[derive(Debug, Clone)]
pub(super) struct DownloadChunkRequest {
pub(super) ctx: DownloadContext,
pub(super) request: ChunkRequest,
pub(super) remaining: RangeInclusive<u64>,
pub(super) input: DownloadInputBuilder,
pub(super) start_seq: u64,
}

/// Build the request input for the chunk identified by `seq`.
///
/// The chunk's byte range is derived from its position relative to
/// `start_seq` within the `remaining` range, and the final chunk is clamped
/// so it never extends past the end of `remaining`.
fn next_chunk(
    seq: u64,
    remaining: RangeInclusive<u64>,
    part_size: u64,
    start_seq: u64,
    input: DownloadInputBuilder,
) -> DownloadInputBuilder {
    // Number of chunks that precede this one within `remaining`.
    let chunk_offset = seq - start_seq;
    let start = *remaining.start() + chunk_offset * part_size;
    // Inclusive end; clamp the last (possibly short) chunk to the range end.
    let end_inclusive = cmp::min(start + part_size - 1, *remaining.end());
    input.range(header::Range::bytes_inclusive(start, end_inclusive))
}

/// handler (service fn) for a single chunk
async fn download_chunk_handler(
request: DownloadChunkRequest,
) -> Result<ChunkResponse, error::Error> {
let ctx = request.ctx;
let request = request.request;

let op = request.input.into_sdk_operation(ctx.client());

let seq = ctx.next_seq();
let part_size = ctx.handle.download_part_size_bytes();
let input = next_chunk(
seq,
request.remaining,
part_size,
request.start_seq,
request.input,
);

let op = input.into_sdk_operation(ctx.client());
let mut resp = op
.send()
.await
Expand All @@ -43,12 +64,12 @@ async fn download_chunk_handler(

let bytes = body
.collect()
.instrument(tracing::debug_span!("collect-body", seq = request.seq))
.instrument(tracing::debug_span!("collect-body", seq = seq))
.await
.map_err(error::from_kind(error::ErrorKind::ChunkFailed))?;

Ok(ChunkResponse {
seq: request.seq,
seq,
data: Some(bytes),
})
}
Expand All @@ -68,23 +89,6 @@ pub(super) fn chunk_service(
.service(svc)
}

// FIXME - should probably be enum ChunkRequest { Range(..), Part(..) } or have an inner field like such
#[derive(Debug, Clone)]
pub(super) struct ChunkRequest {
    // byte range to download (inclusive on both ends)
    pub(super) range: RangeInclusive<u64>,
    // input builder used to construct the underlying request for this chunk
    pub(super) input: DownloadInputBuilder,
    // sequence number (carried through to the corresponding ChunkResponse)
    pub(super) seq: u64,
}

impl ChunkRequest {
    /// Size of this chunk request in bytes.
    ///
    /// `range` is inclusive on both ends, hence the `+ 1`.
    pub(super) fn size(&self) -> u64 {
        self.range.end() - self.range.start() + 1
    }
}

#[derive(Debug, Clone)]
pub(crate) struct ChunkResponse {
// TODO(aws-sdk-rust#1159, design) - consider PartialOrd for ChunkResponse and hiding `seq` as internal only detail
Expand All @@ -110,31 +114,18 @@ pub(super) fn distribute_work(
start_seq: u64,
comp_tx: mpsc::Sender<Result<ChunkResponse, error::Error>>,
) {
let end = *remaining.end();
let mut pos = *remaining.start();
let mut remaining = end - pos + 1;
let mut seq = start_seq;

let svc = chunk_service(&handle.ctx);

let part_size = handle.ctx.target_part_size_bytes();
let input: DownloadInputBuilder = input.into();

while remaining > 0 {
let start = pos;
let end_inclusive = cmp::min(pos + part_size - 1, end);

let chunk_req = next_chunk(start, end_inclusive, seq, input.clone());
tracing::trace!(
"distributing chunk(size={}): {:?}",
chunk_req.size(),
chunk_req
);
let chunk_size = chunk_req.size();

let size = *remaining.end() - *remaining.start() + 1;
let num_parts = size.div_ceil(part_size);
for seq in 0..num_parts {
let req = DownloadChunkRequest {
ctx: handle.ctx.clone(),
request: chunk_req,
remaining: remaining.clone(),
input: input.clone(),
start_seq,
};

let svc = svc.clone();
Expand All @@ -147,25 +138,8 @@ pub(super) fn distribute_work(
}
}
.instrument(tracing::debug_span!("download-chunk", seq = seq));

handle.tasks.spawn(task);

seq += 1;
remaining -= chunk_size;
tracing::trace!("remaining = {}", remaining);
pos += chunk_size;
}

tracing::trace!("work fully distributed");
}

fn next_chunk(
start: u64,
end_inclusive: u64,
seq: u64,
input: DownloadInputBuilder,
) -> ChunkRequest {
let range = start..=end_inclusive;
let input = input.range(header::Range::bytes_inclusive(start, end_inclusive));
ChunkRequest { seq, range, input }
}
Loading