From b764af8bec2e59bfb69f6f21800d9761f8f58f6e Mon Sep 17 00:00:00 2001 From: Pablo Arteaga <46710067+vagaerg@users.noreply.github.com> Date: Thu, 18 Jul 2024 12:31:22 +0100 Subject: [PATCH] Add more thorough tests for aws-chunked (#110) --- .../signing/InternalChunkSigningSession.java | 4 +- .../server/AbstractTestProxiedRequests.java | 14 +- .../proxy/server/TestGenericRestRequests.java | 214 +++++++++++++----- .../proxy/server/TestPresignedRequests.java | 20 +- .../signing/TestingChunkSigningSession.java | 47 +++- .../aws/proxy/server/testing/TestingUtil.java | 40 ++++ 6 files changed, 248 insertions(+), 91 deletions(-) diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/signing/InternalChunkSigningSession.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/signing/InternalChunkSigningSession.java index 1af9bc56..b712e082 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/signing/InternalChunkSigningSession.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/signing/InternalChunkSigningSession.java @@ -13,6 +13,7 @@ */ package io.trino.aws.proxy.server.signing; +import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hasher; import com.google.common.hash.Hashing; import io.airlift.log.Logger; @@ -28,7 +29,8 @@ class InternalChunkSigningSession { private static final Logger log = Logger.get(InternalChunkSigningSession.class); - private final ChunkSigner chunkSigner; + @VisibleForTesting + protected final ChunkSigner chunkSigner; private String previousSignature; private String expectedSignature; private Hasher hasher; diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java index 689cc924..c4e08167 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java @@ -13,6 +13,7 @@ */ package io.trino.aws.proxy.server; +import io.trino.aws.proxy.server.testing.TestingUtil; import jakarta.annotation.PreDestroy; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; @@ -25,13 +26,11 @@ import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.Delete; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectResponse; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.ListBucketsResponse; import software.amazon.awssdk.services.s3.model.ListObjectsResponse; -import software.amazon.awssdk.services.s3.model.ObjectIdentifier; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3Object; @@ -80,16 +79,7 @@ public void shutdown() @AfterEach public void cleanupBuckets() { - remoteClient.listBuckets().buckets().forEach(bucket -> remoteClient.listObjectsV2Paginator(request -> request.bucket(bucket.name())).forEach(s3ObjectPage -> { - if (s3ObjectPage.contents().isEmpty()) { - return; - } - List objectIdentifiers = s3ObjectPage.contents() - .stream() - .map(s3Object -> ObjectIdentifier.builder().key(s3Object.key()).build()) - .collect(toImmutableList()); - remoteClient.deleteObjects(deleteRequest -> deleteRequest.bucket(bucket.name()).delete(Delete.builder().objects(objectIdentifiers).build())); - })); + TestingUtil.cleanupBuckets(remoteClient); } @Test diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java index 7ad41495..c27ac21b 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java @@ -13,14 +13,21 @@ */ package io.trino.aws.proxy.server; +import com.google.common.base.Splitter; import com.google.inject.Inject; import io.airlift.http.client.HttpClient; import io.airlift.http.client.Request; import io.airlift.http.client.StatusResponseHandler.StatusResponse; import io.airlift.http.server.testing.TestingHttpServer; import io.airlift.units.Duration; +import io.trino.aws.proxy.server.credentials.CredentialsController; +import io.trino.aws.proxy.server.rest.RequestLoggerController; import io.trino.aws.proxy.server.rest.TrinoS3ProxyConfig; +import io.trino.aws.proxy.server.signing.InternalSigningController; +import io.trino.aws.proxy.server.signing.SigningControllerConfig; +import io.trino.aws.proxy.server.signing.TestingChunkSigningSession; import io.trino.aws.proxy.server.testing.TestingCredentialsRolesProvider; +import io.trino.aws.proxy.server.testing.TestingRemoteS3Facade; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; @@ -28,18 +35,34 @@ import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.signing.RequestAuthorization; +import io.trino.aws.proxy.spi.signing.SigningMetadata; +import io.trino.aws.proxy.spi.signing.SigningServiceType; +import io.trino.aws.proxy.spi.util.AwsTimestamp; +import io.trino.aws.proxy.spi.util.ImmutableMultiMap; import jakarta.ws.rs.core.UriBuilder; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.signer.internal.chunkedencoding.AwsS3V4ChunkSigner; import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; +import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.List; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.function.Function; import static io.airlift.http.client.Request.Builder.preparePut; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; +import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; +import static io.trino.aws.proxy.server.testing.TestingUtil.cleanupBuckets; +import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -52,29 +75,6 @@ public class TestGenericRestRequests private final Credentials testingCredentials; private final S3Client storageClient; - private static final String goodChunkedContent = """ - 7b;chunk-signature=20e300fbbad6946a482aaa7de0bdc8f592d4c372306dd746a22d18b7b66b4527\r - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\r - 0;chunk-signature=ae4265701a9e0796d671d3339c71db240c0c87b2f6e2f9c6ca7cd781fdcf641a\r - \r - """; - - // first chunk-signature is bad - private static final String badChunkedContent1 = """ - 7b;chunk-signature=10e300fbbad6946a482aaa7de0bdc8f592d4c372306dd746a22d18b7b66b4527\r - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\r - 0;chunk-signature=ae4265701a9e0796d671d3339c71db240c0c87b2f6e2f9c6ca7cd781fdcf641a\r - \r - """; - - // second chunk-signature is bad - private static final String badChunkedContent2 = """ - 7b;chunk-signature=20e300fbbad6946a482aaa7de0bdc8f592d4c372306dd746a22d18b7b66b4527\r - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\r - 0;chunk-signature=9e4265701a9e0796d671d3339c71db240c0c87b2f6e2f9c6ca7cd781fdcf641a\r - \r - """; - private static final String goodContent = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Viverra aliquet eget sit amet tellus cras adipiscing. Viverra mauris in aliquam sem fringilla. Facilisis mauris sit amet massa vitae. Mauris vitae ultricies leo integer malesuada. Sed libero enim sed faucibus turpis in eu mi bibendum. Lorem sed risus ultricies tristique nulla aliquet enim. Quis blandit turpis cursus in hac habitasse platea dictumst quisque. Diam maecenas ultricies mi eget mauris pharetra et ultrices neque. Aliquam sem fringilla ut morbi."; // first char is different case @@ -110,18 +110,89 @@ public TestGenericRestRequests( this.storageClient = requireNonNull(storageClient, "storageClient is null"); } + @AfterEach + public void cleanupStorage() + { + cleanupBuckets(storageClient); + storageClient.listBuckets().buckets().forEach(bucket -> storageClient.deleteBucket(DeleteBucketRequest.builder().bucket(bucket.name()).build())); + } + @Test - public void testAwsChunkedUpload() + public void testAwsChunkedUploadValid() + throws IOException { - Credential credential = new Credential("c160cd8c-8273-4e34-bcf5-3dbddec0c6e0", "464cbc68-2d4f-4e4d-b653-5b1630db9f56"); - Credentials credentials = new Credentials(credential, testingCredentials.remote(), Optional.empty()); - credentialsRolesProvider.addCredentials(credentials); + String bucket = "test-aws-chunked"; + storageClient.createBucket(r -> r.bucket(bucket).build()); - storageClient.createBucket(r -> r.bucket("two").build()); + Credential validCredential = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + credentialsRolesProvider.addCredentials(Credentials.build(validCredential, testingCredentials.requiredRemoteCredential())); - assertThat(doAwsChunkedUpload(goodChunkedContent).getStatusCode()).isEqualTo(200); - assertThat(doAwsChunkedUpload(badChunkedContent1).getStatusCode()).isEqualTo(401); - assertThat(doAwsChunkedUpload(badChunkedContent2).getStatusCode()).isEqualTo(401); + // Upload in 2 chunks + assertThat(doAwsChunkedUpload(bucket, "aws-chunked-2-partitions", goodContent, 2, validCredential).getStatusCode()).isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, "aws-chunked-2-partitions")).isEqualTo(goodContent); + + // Upload in 3 chunks + assertThat(doAwsChunkedUpload(bucket, "aws-chunked-3-partitions", goodContent, 3, validCredential).getStatusCode()).isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, "aws-chunked-3-partitions")).isEqualTo(goodContent); + } + + @Test + public void testAwsChunkedUploadCornerCases() + throws IOException + { + String bucket = "test-aws-chunked"; + String fileKey = "sample_file_chunked"; + storageClient.createBucket(r -> r.bucket(bucket).build()); + + Credential validCredential = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + credentialsRolesProvider.addCredentials(Credentials.build(validCredential, testingCredentials.requiredRemoteCredential())); + Credential validCredentialTwo = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + credentialsRolesProvider.addCredentials(Credentials.build(validCredentialTwo, testingCredentials.requiredRemoteCredential())); + Credential unknownCredential = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + + // Credential is not known to the credential controller + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, unknownCredential).getStatusCode()).isEqualTo(401); + assertFileNotInS3(storageClient, bucket, fileKey); + + // The request and the chunks are signed with different keys - both valid, but not matching + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, validCredential, validCredentialTwo, Function.identity()).getStatusCode()).isEqualTo(401); + assertFileNotInS3(storageClient, bucket, fileKey); + + // Final chunk has an invalid size + Function changeSizeOfFinalChunk = chunked -> chunked.replaceFirst("\\r\\n0;chunk-signature=(\\w+)", "\r\n1;chunk-signature=$1"); + // TODO: this currently will be accepted, we need to add stricter validation - use a different key so it does not interfere with other cases + assertThat(doAwsChunkedUpload(bucket, "final_chunk_invalid_size", goodContent, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500); + // assertFileNotInS3(storageClient, bucket, "final_chunk_invalid_size"); + + // First chunk has an invalid size + Function changeSizeOfFirstChunk = chunked -> { + int firstChunkIdx = chunked.indexOf(";"); + String firstChunkSizeString = chunked.substring(0, firstChunkIdx); + int firstChunkSize = Integer.parseInt(firstChunkSizeString.strip(), 16); + int newSize = firstChunkSize - 1; + String newSizeAsString = Integer.toString(newSize, 16); + // We need to ensure the size (in string form) remains the same so the Content-Length is unchanged + if (newSizeAsString.length() < firstChunkSizeString.length()) { + newSizeAsString = "0" + newSizeAsString; + } + return "%s%s".formatted(newSizeAsString, chunked.substring(firstChunkIdx)); + }; + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, validCredential, changeSizeOfFirstChunk).getStatusCode()).isEqualTo(500); + assertFileNotInS3(storageClient, bucket, fileKey); + + // Change the signature of each of the chunks + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 3, validCredential, getMutatorToBreakSignatureForChunk(0)).getStatusCode()).isEqualTo(401); + assertFileNotInS3(storageClient, bucket, fileKey); + + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 3, validCredential, getMutatorToBreakSignatureForChunk(1)).getStatusCode()).isEqualTo(401); + assertFileNotInS3(storageClient, bucket, fileKey); + + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 3, validCredential, getMutatorToBreakSignatureForChunk(2)).getStatusCode()).isEqualTo(401); + assertFileNotInS3(storageClient, bucket, fileKey); + + // Sanity check: uploads work with this key if we do not interfere + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, validCredential).getStatusCode()).isEqualTo(200); + assertThat(getFileFromStorage(storageClient, bucket, fileKey)).isEqualTo(goodContent); } @Test @@ -139,6 +210,45 @@ public void testPutObject() assertThat(doPutObject(badContent, badSha256).getStatusCode()).isEqualTo(401); } + private StatusResponse doAwsChunkedUpload(String bucket, String key, String contentToUpload, int partitionCount, Credential credential) + { + return doAwsChunkedUpload(bucket, key, contentToUpload, partitionCount, credential, Function.identity()); + } + + private StatusResponse doAwsChunkedUpload(String bucket, String key, String contentToUpload, int partitionCount, Credential credential, Function chunkedPayloadMutator) + { + return doAwsChunkedUpload(bucket, key, contentToUpload, partitionCount, credential, credential, chunkedPayloadMutator); + } + + private StatusResponse doAwsChunkedUpload(String bucket, String key, String contentToUpload, int partitionCount, Credential requestSigningCredential, Credential chunkSigningCredential, Function chunkedPayloadMutator) + { + ImmutableMultiMap.Builder requestHeaderBuilder = ImmutableMultiMap.builder(false); + Instant requestDate = Instant.now(); + requestHeaderBuilder + .add("Host", "%s:%d".formatted(baseUri.getHost(), baseUri.getPort())) + .add("X-Amz-Date", AwsTimestamp.toRequestFormat(requestDate)) + .add("X-Amz-Content-Sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") + .add("X-Amz-Decoded-Content-Length", String.valueOf(contentToUpload.length())) + .add("Content-Length", String.valueOf(TestingChunkSigningSession.getExpectedChunkedStreamSize(contentToUpload, partitionCount))) + .add("Content-Type", "text/plain") + .add("Content-Encoding", "aws-chunked"); + InternalSigningController signingController = new InternalSigningController( + new CredentialsController(new TestingRemoteS3Facade(), credentialsRolesProvider), + new SigningControllerConfig().setMaxClockDrift(new Duration(10, TimeUnit.SECONDS)), + new RequestLoggerController()); + + URI requestUri = UriBuilder.fromUri(baseUri).path(bucket).path(key).build(); + RequestAuthorization requestAuthorization = signingController.signRequest(new SigningMetadata(SigningServiceType.S3, Credentials.build(requestSigningCredential, testingCredentials.requiredRemoteCredential()), Optional.empty()), + "us-east-1", requestDate, Optional.empty(), Credentials::emulated, requestUri, requestHeaderBuilder.build(), ImmutableMultiMap.empty(), "PUT"); + String chunkedContent = chunkedPayloadMutator.apply(TestingChunkSigningSession.build(chunkSigningCredential, requestAuthorization.signature(), requestDate).generateChunkedStream(contentToUpload, partitionCount)); + Request.Builder requestBuilder = preparePut().setUri(requestUri); + + requestHeaderBuilder.add("Authorization", requestAuthorization.authorization()); + requestHeaderBuilder.build().forEachEntry(requestBuilder::addHeader); + requestBuilder.setBodyGenerator(createStaticBodyGenerator(chunkedContent.getBytes(StandardCharsets.UTF_8))); + return httpClient.execute(requestBuilder.build(), createStatusResponseHandler()); + } + private StatusResponse doPutObject(String content, String sha256) { URI uri = UriBuilder.fromUri(baseUri) @@ -164,29 +274,25 @@ private StatusResponse doPutObject(String content, String sha256) return httpClient.execute(request, createStatusResponseHandler()); } - private StatusResponse doAwsChunkedUpload(String content) + private static Function getMutatorToBreakSignatureForChunk(int chunkNumber) { - URI uri = UriBuilder.fromUri(baseUri) - .path("two") - .path("test") - .build(); - - // values discovered from an AWS CLI request sent to a dummy local HTTP server - Request request = preparePut().setUri(uri) - .setHeader("Host", "127.0.0.1:62820") - .setHeader("User-Agent", "aws-sdk-java/2.25.32 Mac_OS_X/13.6.7 OpenJDK_64-Bit_Server_VM/22.0.1+8-16 Java/22.0.1 kotlin/1.9.23-release-779 vendor/Oracle_Corporation io/sync http/Apache cfg/retry-mode/legacy") - .setHeader("X-Amz-Date", "20240618T080640Z") - .setHeader("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") - .setHeader("x-amz-decoded-content-length", "123") - .setHeader("Authorization", "AWS4-HMAC-SHA256 Credential=c160cd8c-8273-4e34-bcf5-3dbddec0c6e0/20240618/us-east-1/s3/aws4_request, SignedHeaders=amz-sdk-invocation-id;amz-sdk-request;content-encoding;content-length;content-type;host;x-amz-content-sha256;x-amz-date;x-amz-decoded-content-length, Signature=3bdce17ef4446ba2900c8f90b2e8ee812ccfa4625abb67030fae01dd1a9d347b") - .setHeader("Content-Encoding", "aws-chunked") - .setHeader("amz-sdk-invocation-id", "0c59609c-1c7b-e503-0583-b0271b5e8b21") - .setHeader("amz-sdk-request", "attempt=1; max=4") - .setHeader("Content-Length", "296") - .setHeader("Content-Type", "text/plain") - .setBodyGenerator(createStaticBodyGenerator(content, StandardCharsets.UTF_8)) - .build(); - - return httpClient.execute(request, createStatusResponseHandler()); + return chunkedContent -> { + int remainingChunks = chunkNumber; + StringBuilder resultBuilder = new StringBuilder(); + List parts = Splitter.on("\r\n").omitEmptyStrings().splitToList(chunkedContent); + for (String part : parts) { + if (part.contains(";chunk-signature=")) { + if (remainingChunks-- == 0) { + resultBuilder.append(part.replaceFirst("([0-9a-f]+;chunk-signature=)(\\w+)", "$1" + "0".repeat(AwsS3V4ChunkSigner.getSignatureLength()))); + resultBuilder.append("\r\n"); + continue; + } + } + resultBuilder.append(part); + resultBuilder.append("\r\n"); + } + resultBuilder.append("\r\n"); + return resultBuilder.toString(); + }; } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java index b7715c79..274b85cb 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java @@ -25,6 +25,7 @@ import io.airlift.http.client.StringResponseHandler.StringResponse; import io.airlift.http.server.testing.TestingHttpServer; import io.trino.aws.proxy.server.rest.TrinoS3ProxyConfig; +import io.trino.aws.proxy.server.testing.TestingUtil; import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; @@ -44,7 +45,6 @@ import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.presigner.S3Presigner; import software.amazon.awssdk.services.s3.presigner.model.CompleteMultipartUploadPresignRequest; @@ -60,7 +60,6 @@ import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest; import software.amazon.awssdk.services.s3.presigner.model.UploadPartPresignRequest; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; @@ -76,9 +75,9 @@ import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler; import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, TestProxiedRequests.Filter.class}) public class TestPresignedRequests @@ -159,10 +158,7 @@ public void testPresignedPut() assertThat(response.getStatusCode()).isEqualTo(200); } - GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket("two").key("presignedPut").build(); - ByteArrayOutputStream readContents = new ByteArrayOutputStream(); - internalClient.getObject(getObjectRequest).transferTo(readContents); - assertThat(readContents.toString()).isEqualTo(fileContents); + assertThat(getFileFromStorage("two", "presignedPut")).isEqualTo(fileContents); } @Test @@ -186,10 +182,7 @@ public void testPresignedDelete() assertThat(response.getStatusCode()).isEqualTo(204); } - assertThatExceptionOfType(S3Exception.class) - .isThrownBy(() -> getFileFromStorage("three", "fileToDelete")) - .extracting(S3Exception::statusCode) - .isEqualTo(404); + assertFileNotInS3(storageClient, "three", "fileToDelete"); } @Test @@ -302,10 +295,7 @@ private void uploadFileToStorage(String bucketName, String key, Path filePath) private String getFileFromStorage(String bucketName, String key) throws IOException { - GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(key).build(); - ByteArrayOutputStream readContents = new ByteArrayOutputStream(); - internalClient.getObject(getObjectRequest).transferTo(readContents); - return readContents.toString(); + return TestingUtil.getFileFromStorage(internalClient, bucketName, key); } private S3Presigner buildPresigner() diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java index f38ef2fb..67552229 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java @@ -24,6 +24,9 @@ import software.amazon.awssdk.regions.Region; import java.time.Instant; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; +import java.util.Locale; import java.util.UUID; import static com.google.common.base.Preconditions.checkArgument; @@ -33,9 +36,9 @@ public class TestingChunkSigningSession extends InternalChunkSigningSession { + private static final DateTimeFormatter CHUNK_DATETIME_FORMAT = DateTimeFormatter.ofPattern("yyyyMMdd", Locale.US).withZone(ZoneId.of("Z")); + private final String seed; - private final Instant instant; - private final byte[] signingKey; public static TestingChunkSigningSession build() { @@ -45,6 +48,11 @@ public static TestingChunkSigningSession build() } public static TestingChunkSigningSession build(Credential credential, String seed) + { + return build(credential, seed, Instant.now()); + } + + public static TestingChunkSigningSession build(Credential credential, String seed, Instant instant) { AwsCredentials credentials = AwsBasicCredentials.create(credential.accessKey(), credential.secretKey()); Aws4SignerParams.Builder builder = Aws4SignerParams.builder() @@ -54,7 +62,31 @@ public static TestingChunkSigningSession build(Credential credential, String see .signingRegion(Region.US_EAST_1); byte[] signingKey = Signer.signingKey(credentials, new Aws4SignerRequestParams(builder.build())); - return new TestingChunkSigningSession(seed, Instant.now(), signingKey); + return new TestingChunkSigningSession(seed, instant, signingKey, "%s/us-east-1/s3/aws4_request".formatted(CHUNK_DATETIME_FORMAT.format(instant))); + } + + public static int getExpectedChunkedStreamSize(String rawContent, int partitions) + { + int contentSizeInBytes = rawContent.getBytes(UTF_8).length; + int standardChunkSize = Math.ceilDiv(contentSizeInBytes, partitions); + // The penultimate chunk may be smaller if alignment is not perfect + int penultimateChunkSize = contentSizeInBytes - (standardChunkSize * (partitions - 1)); + // Each chunk has: + // - A header consisting of ";chunk-signature=" + // - \r\n + // - + // - \r\n + int baseChunkSize = ";chunk-signature=".length() + AwsS3V4ChunkSigner.getSignatureLength() + 4; + // Chunk headers without including the size + return (baseChunkSize * (partitions + 1)) + + // Size of the size field for all chunk except for the last 2 + (Integer.toString(standardChunkSize, 16).length() * (partitions - 1)) + + // Size of the size field for the penultimate chunk + (Integer.toString(penultimateChunkSize, 16).length()) + + // Size of the size field for the last chunk (size=0, 1 character) + 1 + + // Size of the actual content + contentSizeInBytes; } @SuppressWarnings("UnstableApiUsage") @@ -62,11 +94,10 @@ public String generateChunkedStream(String content, int partitions) { checkArgument(partitions > 1, "partitions must be greater than 1"); - ChunkSigner chunkSigner = new ChunkSigner(instant, "/dummy", signingKey); String previousSignature = seed; StringBuilder chunkedStream = new StringBuilder(); - int chunkSize = content.length() / partitions; + int chunkSize = Math.ceilDiv(content.length(), partitions); int index = 0; while (index < content.length()) { int thisLength = Math.min(chunkSize, content.length() - index); @@ -88,12 +119,10 @@ public String generateChunkedStream(String content, int partitions) return chunkedStream.toString(); } - private TestingChunkSigningSession(String seed, Instant instant, byte[] signingKey) + private TestingChunkSigningSession(String seed, Instant instant, byte[] signingKey, String keyPath) { - super(new ChunkSigner(instant, "/dummy", signingKey), seed); + super(new ChunkSigner(instant, keyPath, signingKey), seed); this.seed = requireNonNull(seed, "seed is null"); - this.instant = requireNonNull(instant, "instant is null"); - this.signingKey = requireNonNull(signingKey, "signingKey is null"); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java index 8d5e15f7..5fd31f67 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java @@ -20,21 +20,30 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.s3.model.Delete; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.ObjectIdentifier; +import software.amazon.awssdk.services.s3.model.S3Exception; +import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.IOException; import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.net.URI; import java.nio.file.Path; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.stream.Stream; import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.annotation.ElementType.FIELD; import static java.lang.annotation.ElementType.METHOD; import static java.lang.annotation.ElementType.PARAMETER; import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; public final class TestingUtil { @@ -78,4 +87,35 @@ public static File findTestJar(String name) .findFirst() .orElseThrow(() -> new AssertionError("Unable to find test jar: " + name)); } + + public static String getFileFromStorage(S3Client storageClient, String bucketName, String key) + throws IOException + { + GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(key).build(); + ByteArrayOutputStream readContents = new ByteArrayOutputStream(); + storageClient.getObject(getObjectRequest).transferTo(readContents); + return readContents.toString(); + } + + public static void cleanupBuckets(S3Client storageClient) + { + storageClient.listBuckets().buckets().forEach(bucket -> storageClient.listObjectsV2Paginator(request -> request.bucket(bucket.name())).forEach(s3ObjectPage -> { + if (s3ObjectPage.contents().isEmpty()) { + return; + } + List objectIdentifiers = s3ObjectPage.contents() + .stream() + .map(s3Object -> ObjectIdentifier.builder().key(s3Object.key()).build()) + .collect(toImmutableList()); + storageClient.deleteObjects(deleteRequest -> deleteRequest.bucket(bucket.name()).delete(Delete.builder().objects(objectIdentifiers).build())); + })); + } + + public static void assertFileNotInS3(S3Client storageClient, String bucket, String key) + { + assertThatExceptionOfType(S3Exception.class) + .isThrownBy(() -> getFileFromStorage(storageClient, bucket, key)) + .extracting(S3Exception::statusCode) + .isEqualTo(404); + } }