Skip to content

Commit 6cb6301

Browse files
committed
feat(storage:s3): multi-part upload: upload parts concurrently
1 parent b4d97d8 commit 6cb6301

File tree

3 files changed

+118
-86
lines changed

3 files changed

+118
-86
lines changed

checkstyle/suppressions.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<suppress checks="ClassDataAbstractionCoupling" files=".*Test\.java"/>
2323
<suppress checks="ClassFanOutComplexity" files=".*Test\.java"/>
2424
<suppress checks="ClassFanOutComplexity" files="RemoteStorageManager.java"/>
25+
<suppress checks="ClassDataAbstractionCoupling" files="S3MultiPartOutputStream.java"/>
2526
<suppress checks="ClassDataAbstractionCoupling" files="S3StorageConfig.java"/>
2627
<suppress checks="ClassDataAbstractionCoupling" files="RemoteStorageManager.java"/>
2728
</suppressions>

storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java

Lines changed: 58 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,34 @@
1818

1919
import java.io.ByteArrayInputStream;
2020
import java.io.IOException;
21-
import java.io.InputStream;
2221
import java.io.OutputStream;
2322
import java.nio.ByteBuffer;
24-
import java.util.ArrayList;
25-
import java.util.List;
23+
import java.util.Map;
2624
import java.util.Objects;
25+
import java.util.concurrent.CompletableFuture;
26+
import java.util.concurrent.ConcurrentHashMap;
27+
import java.util.concurrent.ExecutionException;
28+
import java.util.concurrent.atomic.AtomicInteger;
29+
import java.util.stream.Collectors;
2730

2831
import com.amazonaws.services.s3.AmazonS3;
2932
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
3033
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
31-
import com.amazonaws.services.s3.model.CompleteMultipartUploadResult;
3234
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
3335
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
3436
import com.amazonaws.services.s3.model.PartETag;
3537
import com.amazonaws.services.s3.model.UploadPartRequest;
36-
import com.amazonaws.services.s3.model.UploadPartResult;
3738
import org.slf4j.Logger;
3839
import org.slf4j.LoggerFactory;
3940

4041
/**
4142
* S3 multipart output stream.
4243
* Enable uploads to S3 with unknown size by feeding input bytes to multiple parts and upload them on close.
4344
*
45+
* <p>OutputStream is used to write sequentially, but
46+
* uploading parts happens asynchronously to reduce the overall upload latency.
47+
* Concurrency happens within the output stream implementation and does not require changes on the callers.
48+
*
4449
* <p>Requires S3 client and starts a multipart transaction when instantiated. Do not reuse.
4550
*/
4651
public class S3MultiPartOutputStream extends OutputStream {
@@ -54,7 +59,9 @@ public class S3MultiPartOutputStream extends OutputStream {
5459
final int partSize;
5560

5661
private final String uploadId;
57-
private final List<PartETag> partETags = new ArrayList<>();
62+
private CompletableFuture<Map<Integer, String>> partUploads =
63+
CompletableFuture.completedFuture(new ConcurrentHashMap<>());
64+
private final AtomicInteger partNumber = new AtomicInteger(0);
5865

5966
private boolean closed;
6067

@@ -88,32 +95,45 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
8895
}
8996
final ByteBuffer source = ByteBuffer.wrap(b, off, len);
9097
while (source.hasRemaining()) {
91-
final int transferred = Math.min(partBuffer.remaining(), source.remaining());
92-
final int offset = source.arrayOffset() + source.position();
93-
// TODO: get rid of this array copying
94-
partBuffer.put(source.array(), offset, transferred);
95-
source.position(source.position() + transferred);
98+
final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
99+
final int positionAfterCopying = source.position() + toCopy;
100+
source.limit(positionAfterCopying);
101+
partBuffer.put(source.slice());
102+
source.clear(); // reset limit
103+
source.position(positionAfterCopying);
96104
if (!partBuffer.hasRemaining()) {
97-
flushBuffer(0, partSize);
105+
partBuffer.position(0);
106+
partBuffer.limit(partSize);
107+
flushBuffer(partBuffer.slice(), partSize);
108+
partBuffer.clear();
98109
}
99110
}
100111
}
101112

102113
@Override
103114
public void close() throws IOException {
104115
if (partBuffer.position() > 0) {
105-
flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
116+
final int actualPartSize = partBuffer.position();
117+
partBuffer.position(0);
118+
partBuffer.limit(actualPartSize);
119+
flushBuffer(partBuffer.slice(), actualPartSize);
106120
}
107121
if (Objects.nonNull(uploadId)) {
108-
if (!partETags.isEmpty()) {
122+
if (partNumber.get() > 0) {
109123
try {
110-
final CompleteMultipartUploadRequest request =
111-
new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
112-
final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
124+
// wait for all uploads to complete successfully before committing
125+
final var partETags = partUploads.get(); // TODO: maybe set a timeout?
126+
final var tags = partETags.entrySet()
127+
.stream()
128+
.map(entry -> new PartETag(entry.getKey(), entry.getValue()))
129+
.collect(Collectors.toList());
130+
final var request = new CompleteMultipartUploadRequest(bucketName, key, uploadId, tags);
131+
final var result = client.completeMultipartUpload(request);
113132
log.debug("Completed multipart upload {} with result {}", uploadId, result);
114-
} catch (final Exception e) {
133+
} catch (final InterruptedException | ExecutionException e) {
115134
log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
116135
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
136+
throw new IOException(e);
117137
}
118138
} else {
119139
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,31 +142,32 @@ public void close() throws IOException {
122142
closed = true;
123143
}
124144

125-
private void flushBuffer(final int offset,
126-
final int actualPartSize) throws IOException {
145+
private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) throws IOException {
127146
try {
128-
final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
129-
uploadPart(in, actualPartSize);
130-
partBuffer.clear();
147+
final byte[] array = new byte[actualPartSize];
148+
partBuffer.get(array, 0, actualPartSize);
149+
150+
final UploadPartRequest uploadPartRequest =
151+
new UploadPartRequest()
152+
.withBucketName(bucketName)
153+
.withKey(key)
154+
.withUploadId(uploadId)
155+
.withPartSize(actualPartSize)
156+
.withPartNumber(partNumber.incrementAndGet())
157+
.withInputStream(new ByteArrayInputStream(array));
158+
159+
// Run request async
160+
partUploads = partUploads.thenCombine(
161+
CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest)),
162+
(partETags, result) -> {
163+
partETags.put(result.getPartETag().getPartNumber(), result.getPartETag().getETag());
164+
return partETags;
165+
});
131166
} catch (final Exception e) {
132167
log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
133168
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
134169
closed = true;
135170
throw new IOException(e);
136171
}
137172
}
138-
139-
private void uploadPart(final InputStream in, final int actualPartSize) {
140-
final int partNumber = partETags.size() + 1;
141-
final UploadPartRequest uploadPartRequest =
142-
new UploadPartRequest()
143-
.withBucketName(bucketName)
144-
.withKey(key)
145-
.withUploadId(uploadId)
146-
.withPartSize(actualPartSize)
147-
.withPartNumber(partNumber)
148-
.withInputStream(in);
149-
final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
150-
partETags.add(uploadResult.getPartETag());
151-
}
152173
}

storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
package io.aiven.kafka.tieredstorage.storage.s3;
1818

19-
import java.io.ByteArrayInputStream;
2019
import java.io.IOException;
21-
import java.util.ArrayList;
22-
import java.util.List;
20+
import java.util.HashMap;
21+
import java.util.Map;
2322
import java.util.Random;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.stream.Collectors;
2425

2526
import com.amazonaws.services.s3.AmazonS3;
2627
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
8384
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
8485
out.write(new byte[] {1, 2, 3});
8586
}
86-
}).isInstanceOf(IOException.class).hasCause(testException);
87+
}).isInstanceOf(IOException.class).hasRootCause(testException);
8788

8889
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
8990
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -135,102 +136,116 @@ void writesOneByte() throws Exception {
135136
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
136137
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
137138

139+
final UploadPartRequest value = uploadPartRequestCaptor.getValue();
138140
assertUploadPartRequest(
139-
uploadPartRequestCaptor.getValue(),
141+
value,
142+
value.getInputStream().readAllBytes(),
140143
1,
141144
1,
142145
new byte[] {1});
143146
assertCompleteMultipartUploadRequest(
144147
completeMultipartUploadRequestCaptor.getValue(),
145-
List.of(new PartETag(1, "SOME_ETAG"))
148+
Map.of(1, "SOME_ETAG")
146149
);
147150
}
148151

149152
@Test
150153
void writesMultipleMessages() throws Exception {
151154
final int bufferSize = 10;
152-
final byte[] message = new byte[bufferSize];
153155

154156
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
155157
.thenReturn(newInitiateMultipartUploadResult());
158+
159+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
160+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
156161
when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
157-
.thenAnswer(a -> {
158-
final UploadPartRequest up = a.getArgument(0);
162+
.thenAnswer(answer -> {
163+
final UploadPartRequest up = answer.getArgument(0);
164+
// emulate the behavior of the S3 client; otherwise we would capture the wrong array contents in memory
165+
uploadPartRequests.put(up.getPartNumber(), up);
166+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
167+
159168
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
160169
});
161170
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
162171
.thenReturn(new CompleteMultipartUploadResult());
163172

164-
final List<byte[]> expectedMessagesList = new ArrayList<>();
173+
final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
165174
try (final S3MultiPartOutputStream out =
166175
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
167176
for (int i = 0; i < 3; i++) {
177+
final byte[] message = new byte[bufferSize];
168178
random.nextBytes(message);
169179
out.write(message, 0, message.length);
170-
expectedMessagesList.add(message);
180+
expectedMessageParts.put(i + 1, message);
171181
}
172182
}
173183

174184
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
175185
verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
176186
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
177187

178-
final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
179-
int counter = 0;
180-
for (final byte[] expectedMessage : expectedMessagesList) {
188+
for (final Integer part : expectedMessageParts.keySet()) {
181189
assertUploadPartRequest(
182-
uploadRequests.get(counter),
190+
uploadPartRequests.get(part),
191+
uploadPartContents.get(part),
183192
bufferSize,
184-
counter + 1,
185-
expectedMessage);
186-
counter++;
193+
part,
194+
expectedMessageParts.get(part)
195+
);
187196
}
188197
assertCompleteMultipartUploadRequest(
189198
completeMultipartUploadRequestCaptor.getValue(),
190-
List.of(new PartETag(1, "SOME_TAG#1"),
191-
new PartETag(2, "SOME_TAG#2"),
192-
new PartETag(3, "SOME_TAG#3"))
199+
Map.of(1, "SOME_TAG#1",
200+
2, "SOME_TAG#2",
201+
3, "SOME_TAG#3")
193202
);
194203
}
195204

196205
@Test
197206
void writesTailMessages() throws Exception {
198207
final int messageSize = 20;
199208

200-
final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
209+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
210+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
201211

202212
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
203213
.thenReturn(newInitiateMultipartUploadResult());
204214
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
205-
.thenAnswer(a -> {
206-
final UploadPartRequest up = a.getArgument(0);
215+
.thenAnswer(answer -> {
216+
final UploadPartRequest up = answer.getArgument(0);
207217
//emulate behave of S3 client otherwise we will get wrong array in the memory
208-
up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
209-
uploadPartRequests.add(up);
218+
uploadPartRequests.put(up.getPartNumber(), up);
219+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
210220

211221
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
212222
});
213223
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
214224
.thenReturn(new CompleteMultipartUploadResult());
215225

216-
final byte[] message = new byte[messageSize];
217226

218227
final byte[] expectedFullMessage = new byte[messageSize + 10];
219228
final byte[] expectedTailMessage = new byte[10];
220229

221-
final S3MultiPartOutputStream
222-
out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
223-
random.nextBytes(message);
224-
out.write(message);
225-
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
226-
random.nextBytes(message);
227-
out.write(message);
228-
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
229-
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
230+
final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
231+
{
232+
final byte[] message = new byte[messageSize];
233+
random.nextBytes(message);
234+
out.write(message);
235+
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
236+
}
237+
{
238+
final byte[] message = new byte[messageSize];
239+
random.nextBytes(message);
240+
out.write(message);
241+
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
242+
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
243+
}
230244
out.close();
231245

232-
assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
233-
assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
246+
assertThat(uploadPartRequests).hasSize(2);
247+
assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
248+
assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);
234249

235250
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
236251
verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
251266
}
252267

253268
private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
269+
final byte[] bytes,
254270
final int expectedPartSize,
255271
final int expectedPartNumber,
256272
final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRequest(final UploadPartRequest uploadPartRe
259275
assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
260276
assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
261277
assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
262-
assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
278+
assertThat(bytes).isEqualTo(expectedBytes);
263279
}
264280

265281
private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
266-
final List<PartETag> expectedETags) {
282+
final Map<Integer, String> expectedETags) {
267283
assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
268284
assertThat(request.getKey()).isEqualTo(FILE_KEY);
269285
assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
270-
assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
271-
272-
for (int i = 0; i < expectedETags.size(); i++) {
273-
final PartETag expectedETag = expectedETags.get(i);
274-
final PartETag etag = request.getPartETags().get(i);
275-
276-
assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
277-
assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
278-
}
286+
final Map<Integer, String> tags = request.getPartETags().stream()
287+
.collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
288+
assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
279289
}
280290

281291
private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {

0 commit comments

Comments
 (0)