diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java index 60c63b437a..dc530d3f90 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.util.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -56,7 +56,7 @@ public AwsCmdbRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected AwsCmdbRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, TableProviderFactory tableProviderFactory, java.util.Map configOptions) + protected AwsCmdbRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, TableProviderFactory tableProviderFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); tableProviders = tableProviderFactory.getTableProviders(); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java index d5868d33db..7a5099e0a7 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactory.java @@ -38,9 +38,8 @@ import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; import com.amazonaws.services.rds.AmazonRDS; import com.amazonaws.services.rds.AmazonRDSClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; +import software.amazon.awssdk.services.s3.S3Client; import java.util.ArrayList; import java.util.HashMap; @@ -62,12 +61,12 @@ public TableProviderFactory(java.util.Map configOptions) AmazonEC2ClientBuilder.standard().build(), AmazonElasticMapReduceClientBuilder.standard().build(), AmazonRDSClientBuilder.standard().build(), - AmazonS3ClientBuilder.standard().build(), + S3Client.create(), configOptions); } @VisibleForTesting - protected TableProviderFactory(AmazonEC2 ec2, AmazonElasticMapReduce emr, AmazonRDS rds, AmazonS3 amazonS3, java.util.Map configOptions) + protected TableProviderFactory(AmazonEC2 ec2, AmazonElasticMapReduce emr, AmazonRDS rds, S3Client amazonS3, java.util.Map configOptions) { addProvider(new Ec2TableProvider(ec2)); addProvider(new EbsTableProvider(ec2)); diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java index 0387ac6bf7..7ff28b61e5 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProvider.java @@ -29,10 +29,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; -import com.amazonaws.services.s3.model.Owner; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.GetBucketAclRequest; +import software.amazon.awssdk.services.s3.model.GetBucketAclResponse; +import software.amazon.awssdk.services.s3.model.Owner; /** * Maps your S3 Objects to a table. @@ -41,9 +43,9 @@ public class S3BucketsTableProvider implements TableProvider { private static final Schema SCHEMA; - private AmazonS3 amazonS3; + private S3Client amazonS3; - public S3BucketsTableProvider(AmazonS3 amazonS3) + public S3BucketsTableProvider(S3Client amazonS3) { this.amazonS3 = amazonS3; } @@ -84,7 +86,7 @@ public GetTableResponse getTable(BlockAllocator blockAllocator, GetTableRequest @Override public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker) { - for (Bucket next : amazonS3.listBuckets()) { + for (Bucket next : amazonS3.listBuckets().buckets()) { toRow(next, spiller); } } @@ -102,13 +104,15 @@ private void toRow(Bucket bucket, { spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("bucket_name", row, bucket.getName()); - matched &= block.offerValue("create_date", row, bucket.getCreationDate()); + matched &= block.offerValue("bucket_name", row, bucket.name()); + matched &= block.offerValue("create_date", row, bucket.creationDate()); - Owner owner = bucket.getOwner(); + GetBucketAclResponse response = amazonS3.getBucketAcl(GetBucketAclRequest.builder().bucket(bucket.name()).build()); + + Owner owner = response.owner(); if (owner != null) { - matched &= block.offerValue("owner_name", row, bucket.getOwner().getDisplayName()); - matched &= block.offerValue("owner_id", row, bucket.getOwner().getId()); + matched &= block.offerValue("owner_name", row, owner.displayName()); + matched &= block.offerValue("owner_id", row, owner.id()); } return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java index c58315f49e..88179b9382 100644 --- a/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java +++ b/athena-aws-cmdb/src/main/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProvider.java @@ -30,12 +30,12 @@ import com.amazonaws.athena.connector.lambda.metadata.GetTableResponse; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.Owner; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.apache.arrow.vector.types.pojo.Schema; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.Owner; +import software.amazon.awssdk.services.s3.model.S3Object; /** * Maps your S3 Objects to a table. @@ -45,9 +45,9 @@ public class S3ObjectsTableProvider { private static final int MAX_KEYS = 1000; private static final Schema SCHEMA; - private AmazonS3 amazonS3; + private S3Client amazonS3; - public S3ObjectsTableProvider(AmazonS3 amazonS3) + public S3ObjectsTableProvider(S3Client amazonS3) { this.amazonS3 = amazonS3; } @@ -98,42 +98,44 @@ public void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsR "(e.g. where bucket_name='my_bucket'."); } - ListObjectsV2Request req = new ListObjectsV2Request().withBucketName(bucket).withMaxKeys(MAX_KEYS); - ListObjectsV2Result result; + ListObjectsV2Request req = ListObjectsV2Request.builder().bucket(bucket).maxKeys(MAX_KEYS).build(); + ListObjectsV2Response response; do { - result = amazonS3.listObjectsV2(req); - for (S3ObjectSummary objectSummary : result.getObjectSummaries()) { - toRow(objectSummary, spiller); + response = amazonS3.listObjectsV2(req); + for (S3Object s3Object : response.contents()) { + toRow(s3Object, spiller, bucket); } - req.setContinuationToken(result.getNextContinuationToken()); + req = req.toBuilder().continuationToken(response.nextContinuationToken()).build(); } - while (result.isTruncated() && queryStatusChecker.isQueryRunning()); + while (response.isTruncated() && queryStatusChecker.isQueryRunning()); } /** * Maps a DBInstance into a row in our Apache Arrow response block(s). * - * @param objectSummary The S3 ObjectSummary to map. + * @param s3Object The S3 object to map. * @param spiller The BlockSpiller to use when we want to write a matching row to the response. + * @param bucket The name of the S3 bucket * @note The current implementation is rather naive in how it maps fields. It leverages a static * list of fields that we'd like to provide and then explicitly filters and converts each field. */ - private void toRow(S3ObjectSummary objectSummary, - BlockSpiller spiller) + private void toRow(S3Object s3Object, + BlockSpiller spiller, + String bucket) { spiller.writeRows((Block block, int row) -> { boolean matched = true; - matched &= block.offerValue("bucket_name", row, objectSummary.getBucketName()); - matched &= block.offerValue("e_tag", row, objectSummary.getETag()); - matched &= block.offerValue("key", row, objectSummary.getKey()); - matched &= block.offerValue("bytes", row, objectSummary.getSize()); - matched &= block.offerValue("storage_class", row, objectSummary.getStorageClass()); - matched &= block.offerValue("last_modified", row, objectSummary.getLastModified()); + matched &= block.offerValue("bucket_name", row, bucket); + matched &= block.offerValue("e_tag", row, s3Object.eTag()); + matched &= block.offerValue("key", row, s3Object.key()); + matched &= block.offerValue("bytes", row, s3Object.size()); + matched &= block.offerValue("storage_class", row, s3Object.storageClassAsString()); + matched &= block.offerValue("last_modified", row, s3Object.lastModified()); - Owner owner = objectSummary.getOwner(); + Owner owner = s3Object.owner(); if (owner != null) { - matched &= block.offerValue("owner_name", row, owner.getDisplayName()); - matched &= block.offerValue("owner_id", row, owner.getId()); + matched &= block.offerValue("owner_name", row, owner.displayName()); + matched &= block.offerValue("owner_id", row, owner.id()); } return matched ? 1 : 0; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java index 7aeefea094..6c755e65aa 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbMetadataHandlerTest.java @@ -38,7 +38,6 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -46,6 +45,7 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.ArrayList; @@ -75,7 +75,7 @@ public class AwsCmdbMetadataHandlerTest private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList()); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private TableProviderFactory mockTableProviderFactory; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java index 940df77986..09000c9e60 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/AwsCmdbRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -40,6 +39,7 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; @@ -62,7 +62,7 @@ public class AwsCmdbRecordHandlerTest private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList()); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private TableProviderFactory mockTableProviderFactory; diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java index 19a77878e4..c196e379d6 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/TableProviderFactoryTest.java @@ -24,11 +24,11 @@ import com.amazonaws.services.ec2.AmazonEC2; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; import com.amazonaws.services.rds.AmazonRDS; -import com.amazonaws.services.s3.AmazonS3; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.services.s3.S3Client; import java.util.List; import java.util.Map; @@ -51,7 +51,7 @@ public class TableProviderFactoryTest private AmazonRDS mockRds; @Mock - private AmazonS3 amazonS3; + private S3Client amazonS3; private TableProviderFactory factory = new TableProviderFactory(mockEc2, mockEmr, mockRds, amazonS3, com.google.common.collect.ImmutableMap.of()); diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java index f4d6ba505a..8ab8620921 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/AbstractTableProviderTest.java @@ -43,11 +43,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -59,8 +54,16 @@ import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -74,8 +77,6 @@ import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -99,7 +100,7 @@ public abstract class AbstractTableProviderTest private final List mockS3Store = new ArrayList<>(); @Mock - private AmazonS3 amazonS3; + private S3Client amazonS3; @Mock private QueryStatusChecker queryStatusChecker; @@ -129,24 +130,24 @@ public void setUp() { allocator = new BlockAllocatorImpl(); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); mockS3Store.add(byteHolder); - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) - .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); - ByteHolder byteHolder = mockS3Store.get(0); - mockS3Store.remove(0); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + when(amazonS3.getObject(any(GetObjectRequest.class))) + .thenAnswer(new Answer() + { + @Override + public Object answer(InvocationOnMock invocationOnMock) + throws Throwable + { + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(mockS3Store.get(0).getBytes())); + } }); blockSpillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java index cb1372a917..348a077164 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3BucketsTableProviderTest.java @@ -23,9 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; -import com.amazonaws.services.s3.model.Owner; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -33,12 +30,19 @@ import org.mockito.invocation.InvocationOnMock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.GetBucketAclRequest; +import software.amazon.awssdk.services.s3.model.GetBucketAclResponse; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; +import software.amazon.awssdk.services.s3.model.Owner; import java.util.ArrayList; import java.util.Date; import java.util.List; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; public class S3BucketsTableProviderTest @@ -47,7 +51,7 @@ public class S3BucketsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(S3BucketsTableProviderTest.class); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; protected String getIdField() { @@ -87,7 +91,15 @@ protected void setUpRead() values.add(makeBucket(getIdValue())); values.add(makeBucket(getIdValue())); values.add(makeBucket("fake-id")); - return values; + return ListBucketsResponse.builder().buckets(values).build(); + }); + when(mockS3.getBucketAcl(any(GetBucketAclRequest.class))).thenAnswer((InvocationOnMock invocation) -> { + return GetBucketAclResponse.builder() + .owner(Owner.builder() + .displayName("owner_name") + .id("owner_id") + .build()) + .build(); }); } @@ -143,13 +155,10 @@ private void validate(FieldReader fieldReader) private Bucket makeBucket(String id) { - Bucket bucket = new Bucket(); - bucket.setName(id); - Owner owner = new Owner(); - owner.setDisplayName("owner_name"); - owner.setId("owner_id"); - bucket.setOwner(owner); - bucket.setCreationDate(new Date(100_000)); + Bucket bucket = Bucket.builder() + .name(id) + .creationDate(new Date(100_000).toInstant()) + .build(); return bucket; } } diff --git a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java index ec77efc11a..761730ee08 100644 --- a/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java +++ b/athena-aws-cmdb/src/test/java/com/amazonaws/athena/connectors/aws/cmdb/tables/s3/S3ObjectsTableProviderTest.java @@ -23,11 +23,6 @@ import com.amazonaws.athena.connector.lambda.data.BlockUtils; import com.amazonaws.athena.connectors.aws.cmdb.tables.AbstractTableProviderTest; import com.amazonaws.athena.connectors.aws.cmdb.tables.TableProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsV2Request; -import com.amazonaws.services.s3.model.ListObjectsV2Result; -import com.amazonaws.services.s3.model.Owner; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -35,6 +30,11 @@ import org.mockito.invocation.InvocationOnMock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.ListObjectsV2Response; +import software.amazon.awssdk.services.s3.model.Owner; +import software.amazon.awssdk.services.s3.model.S3Object; import java.util.ArrayList; import java.util.Date; @@ -45,7 +45,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; public class S3ObjectsTableProviderTest @@ -54,7 +53,7 @@ public class S3ObjectsTableProviderTest private static final Logger logger = LoggerFactory.getLogger(S3ObjectsTableProviderTest.class); @Mock - private AmazonS3 mockS3; + private S3Client mockS3; protected String getIdField() { @@ -92,25 +91,26 @@ protected void setUpRead() AtomicLong count = new AtomicLong(0); when(mockS3.listObjectsV2(nullable(ListObjectsV2Request.class))).thenAnswer((InvocationOnMock invocation) -> { ListObjectsV2Request request = (ListObjectsV2Request) invocation.getArguments()[0]; - assertEquals(getIdValue(), request.getBucketName()); + assertEquals(getIdValue(), request.bucket()); - ListObjectsV2Result mockResult = mock(ListObjectsV2Result.class); - List values = new ArrayList<>(); - values.add(makeObjectSummary(getIdValue())); - values.add(makeObjectSummary(getIdValue())); - values.add(makeObjectSummary("fake-id")); - when(mockResult.getObjectSummaries()).thenReturn(values); + List values = new ArrayList<>(); + values.add(makeS3Object()); + values.add(makeS3Object()); + ListObjectsV2Response.Builder responseBuilder = ListObjectsV2Response.builder().contents(values); if (count.get() > 0) { - assertNotNull(request.getContinuationToken()); + assertNotNull(request.continuationToken()); } if (count.incrementAndGet() < 2) { - when(mockResult.isTruncated()).thenReturn(true); - when(mockResult.getNextContinuationToken()).thenReturn("token"); + responseBuilder.isTruncated(true); + responseBuilder.nextContinuationToken("token"); + } + else { + responseBuilder.isTruncated(false); } - return mockResult; + return responseBuilder.build(); }); } @@ -167,19 +167,17 @@ private void validate(FieldReader fieldReader) } } - private S3ObjectSummary makeObjectSummary(String id) + private S3Object makeS3Object() { - S3ObjectSummary summary = new S3ObjectSummary(); - Owner owner = new Owner(); - owner.setId("owner_id"); - owner.setDisplayName("owner_name"); - summary.setOwner(owner); - summary.setBucketName(id); - summary.setETag("e_tag"); - summary.setKey("key"); - summary.setSize(100); - summary.setLastModified(new Date(100_000)); - summary.setStorageClass("storage_class"); - return summary; + Owner owner = Owner.builder().id("owner_id").displayName("owner_name").build(); + S3Object s3Object = S3Object.builder() + .owner(owner) + .eTag("e_tag") + .key("key") + .size((long)100) + .lastModified(new Date(100_000).toInstant()) + .storageClass("storage_class") + .build(); + return s3Object; } } diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java index ebe3ce9fa8..ccb54ee88a 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public ClickHouseMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - ClickHouseMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + ClickHouseMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java index da5187bf7f..6728a6c4e1 100644 --- a/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java +++ b/athena-clickhouse/src/main/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseRecordHandler.java @@ -32,14 +32,13 @@ import com.amazonaws.athena.connectors.mysql.MySqlFederationExpressionParser; import com.amazonaws.athena.connectors.mysql.MySqlMuxCompositeHandler; import com.amazonaws.athena.connectors.mysql.MySqlQueryStringBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -76,12 +75,12 @@ public ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig public ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new MySqlQueryStringBuilder(MYSQL_QUOTE_CHARACTER, new MySqlFederationExpressionParser(MYSQL_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + ClickHouseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java index 1b0b1f629f..9adc8e9096 100644 --- a/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java +++ b/athena-clickhouse/src/test/java/com/amazonaws/athena/connectors/clickhouse/ClickHouseMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class ClickHouseMuxJdbcRecordHandlerTest private Map recordHandlerMap; private ClickHouseRecordHandler recordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.recordHandler = Mockito.mock(ClickHouseRecordHandler.class); this.recordHandlerMap = Collections.singletonMap(ClickHouseConstants.NAME, this.recordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java index a552bf34ad..3dd28acccc 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public HiveMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - HiveMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + HiveMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java index 839c8d0f31..95ff9f6a3e 100644 --- a/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java +++ b/athena-cloudera-hive/src/main/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandler.java @@ -28,12 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -60,11 +59,11 @@ public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java } public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new HiveQueryStringBuilder(HIVE_QUOTE_CHARACTER, new HiveFederationExpressionParser(HIVE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java index f7888d4f2f..31035ae1a8 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java +++ b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveMuxRecordHandlerTest.java @@ -29,7 +29,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; @@ -37,6 +36,7 @@ import org.mockito.Mockito; import org.testng.Assert; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -50,7 +50,7 @@ public class HiveMuxRecordHandlerTest private Map recordHandlerMap; private HiveRecordHandler hiveRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -64,7 +64,7 @@ public void setup() { this.hiveRecordHandler = Mockito.mock(HiveRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordHive", this.hiveRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java index 57ebfd68ff..108474f096 100644 --- a/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java +++ b/athena-cloudera-hive/src/test/java/com/amazonaws/athena/connectors/cloudera/HiveRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -41,7 +40,9 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; + import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -62,7 +63,7 @@ public class HiveRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -70,7 +71,7 @@ public class HiveRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java index 0faae53965..8dbac1f9e3 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public ImpalaMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - ImpalaMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + ImpalaMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java index 4c72168dc4..59912af693 100644 --- a/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java +++ b/athena-cloudera-impala/src/main/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandler.java @@ -28,12 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -60,11 +59,11 @@ public ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, ja } public ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new ImpalaQueryStringBuilder(IMPALA_QUOTE_CHARACTER, new ImpalaFederationExpressionParser(IMPALA_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + ImpalaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java index dbbd6aef09..cff80beebb 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaMuxRecordHandlerTest.java @@ -29,7 +29,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; @@ -37,6 +36,7 @@ import org.mockito.Mockito; import org.testng.Assert; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -50,7 +50,7 @@ public class ImpalaMuxRecordHandlerTest private Map recordHandlerMap; private ImpalaRecordHandler impalaRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -64,7 +64,7 @@ public void setup() { this.impalaRecordHandler = Mockito.mock(ImpalaRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordImpala", this.impalaRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java index 8f502431a6..bd0909b48f 100644 --- a/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java +++ b/athena-cloudera-impala/src/test/java/com/amazonaws/athena/connectors/cloudera/ImpalaRecordHandlerTest.java @@ -32,8 +32,8 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -63,7 +63,7 @@ public class ImpalaRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -71,7 +71,7 @@ public class ImpalaRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); diff --git a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java index 9bfe7efd82..5560b39f85 100644 --- a/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java +++ b/athena-cloudwatch-metrics/src/main/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandler.java @@ -40,13 +40,12 @@ import com.amazonaws.services.cloudwatch.model.MetricDataQuery; import com.amazonaws.services.cloudwatch.model.MetricDataResult; import com.amazonaws.services.cloudwatch.model.MetricStat; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Date; @@ -97,19 +96,19 @@ public class MetricsRecordHandler //Used to handle throttling events by applying AIMD congestion control private final ThrottlingInvoker invoker; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final AmazonCloudWatch metrics; public MetricsRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), AmazonCloudWatchClientBuilder.standard().build(), configOptions); } @VisibleForTesting - protected MetricsRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AmazonCloudWatch metrics, java.util.Map configOptions) + protected MetricsRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AmazonCloudWatch metrics, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; diff --git a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java index 51a69cf686..ae25003e62 100644 --- a/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java +++ b/athena-cloudwatch-metrics/src/test/java/com/amazonaws/athena/connectors/cloudwatch/metrics/MetricsRecordHandlerTest.java @@ -47,11 +47,6 @@ import com.amazonaws.services.cloudwatch.model.MetricDataQuery; import com.amazonaws.services.cloudwatch.model.MetricDataResult; import com.amazonaws.services.cloudwatch.model.MetricStat; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; @@ -66,6 +61,14 @@ import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.ArrayList; @@ -115,7 +118,7 @@ public class MetricsRecordHandlerTest private AmazonCloudWatch mockMetrics; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private SecretsManagerClient mockSecretsManager; @@ -132,31 +135,27 @@ public void setUp() handler = new MetricsRecordHandler(mockS3, mockSecretsManager, mockAthena, mockMetrics, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - Mockito.lenient().when(mockS3.putObject(any())) + Mockito.lenient().when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - Mockito.lenient().when(mockS3.getObject(nullable(String.class), nullable(String.class))) + Mockito.lenient().when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); } diff --git a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java index 05f8c25430..7b4aa47596 100644 --- a/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java +++ b/athena-cloudwatch/src/main/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandler.java @@ -39,12 +39,11 @@ import com.amazonaws.services.logs.model.GetQueryResultsResult; import com.amazonaws.services.logs.model.OutputLogEvent; import com.amazonaws.services.logs.model.ResultField; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.List; @@ -82,7 +81,7 @@ public class CloudwatchRecordHandler public CloudwatchRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), AWSLogsClientBuilder.defaultClient(), @@ -90,7 +89,7 @@ public CloudwatchRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected CloudwatchRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AWSLogs awsLogs, java.util.Map configOptions) + protected CloudwatchRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AWSLogs awsLogs, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.awsLogs = awsLogs; diff --git a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java index 5e494cb9d0..758deacb50 100644 --- a/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java +++ b/athena-cloudwatch/src/test/java/com/amazonaws/athena/connectors/cloudwatch/CloudwatchRecordHandlerTest.java @@ -43,11 +43,6 @@ import com.amazonaws.services.logs.model.GetLogEventsRequest; import com.amazonaws.services.logs.model.GetLogEventsResult; import com.amazonaws.services.logs.model.OutputLogEvent; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; @@ -61,7 +56,14 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -77,7 +79,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) @@ -97,7 +98,7 @@ public class CloudwatchRecordHandlerTest private AWSLogs mockAwsLogs; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private SecretsManagerClient mockSecretsManager; @@ -116,31 +117,27 @@ public void setUp() handler = new CloudwatchRecordHandler(mockS3, mockSecretsManager, mockAthena, mockAwsLogs, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(mockS3.getObject(nullable(String.class), nullable(String.class))) + when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); when(mockAwsLogs.getLogEvents(nullable(GetLogEventsRequest.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java index 81dcb42ba7..dd7c643f82 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public DataLakeGen2MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - DataLakeGen2MuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + DataLakeGen2MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java index a31bb1ee3a..f80e8bd0c0 100644 --- a/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java +++ b/athena-datalakegen2/src/main/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2RecordHandler.java @@ -28,12 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -52,12 +51,12 @@ public DataLakeGen2RecordHandler(java.util.Map configOptions) } public DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, DataLakeGen2MetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(DataLakeGen2Constants.DRIVER_CLASS, DataLakeGen2Constants.DEFAULT_PORT)), new DataLakeGen2QueryStringBuilder(QUOTE_CHARACTER, new DataLakeGen2FederationExpressionParser(QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + DataLakeGen2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java index 3d4c0bbb0a..dc2fa02473 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeGen2MuxRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class DataLakeGen2MuxRecordHandlerTest private Map recordHandlerMap; private DataLakeGen2RecordHandler dataLakeGen2RecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.dataLakeGen2RecordHandler = Mockito.mock(DataLakeGen2RecordHandler.class); this.recordHandlerMap = Collections.singletonMap(DataLakeGen2Constants.NAME, this.dataLakeGen2RecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java index 0b77403715..912d328fa3 100644 --- a/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java +++ b/athena-datalakegen2/src/test/java/com/amazonaws/athena/connectors/datalakegen2/DataLakeRecordHandlerTest.java @@ -31,7 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,6 +39,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -56,7 +56,7 @@ public class DataLakeRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -65,7 +65,7 @@ public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java index 3f17c52a56..3d4706a208 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400MuxRecordHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public Db2As400MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - Db2As400MuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + Db2As400MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java index 60898b6524..69d0711852 100644 --- a/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java +++ b/athena-db2-as400/src/main/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandler.java @@ -29,12 +29,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,13 +57,13 @@ public Db2As400RecordHandler(java.util.Map configOptions) */ public Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, null, new DatabaseConnectionInfo(Db2As400Constants.DRIVER_CLASS, Db2As400Constants.DEFAULT_PORT)), new Db2As400QueryStringBuilder(QUOTE_CHARACTER), configOptions); } @VisibleForTesting - Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + Db2As400RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java index e1bd503827..4ca5b947a8 100644 --- a/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java +++ b/athena-db2-as400/src/test/java/com/amazonaws/athena/connectors/db2as400/Db2As400RecordHandlerTest.java @@ -31,7 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,6 +39,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -54,14 +54,14 @@ public class Db2As400RecordHandlerTest { private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @Before public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java index da988d8f3b..94fbe8c395 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2MuxRecordHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public Db2MuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - Db2MuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + Db2MuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java index 731d03b0e1..8e9941f220 100644 --- a/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java +++ b/athena-db2/src/main/java/com/amazonaws/athena/connectors/db2/Db2RecordHandler.java @@ -29,12 +29,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -59,13 +58,13 @@ public Db2RecordHandler(java.util.Map configOptions) */ public Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, null, new DatabaseConnectionInfo(Db2Constants.DRIVER_CLASS, Db2Constants.DEFAULT_PORT)), new Db2QueryStringBuilder(QUOTE_CHARACTER, new Db2FederationExpressionParser(QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + Db2RecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java index 5daef72df5..b7de058f8d 100644 --- a/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java +++ b/athena-db2/src/test/java/com/amazonaws/athena/connectors/db2/Db2RecordHandlerTest.java @@ -31,7 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,6 +39,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -55,14 +55,14 @@ public class Db2RecordHandlerTest { private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @Before public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java index 0b103a82b2..4b0459f57e 100644 --- a/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java +++ b/athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandler.java @@ -28,8 +28,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.docdb.qpt.DocDBQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; import com.mongodb.client.MongoCursor; @@ -41,6 +39,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -79,7 +78,7 @@ public class DocDBRecordHandler public DocDBRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new DocDBConnectionFactory(), @@ -87,7 +86,7 @@ public DocDBRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected DocDBRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, DocDBConnectionFactory connectionFactory, java.util.Map configOptions) + protected DocDBRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, DocDBConnectionFactory connectionFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.connectionFactory = connectionFactory; diff --git a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java index 93d6e63c30..866bc1ac41 100644 --- a/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java +++ b/athena-docdb/src/test/java/com/amazonaws/athena/connectors/docdb/DocDBRecordHandlerTest.java @@ -40,11 +40,6 @@ import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import com.mongodb.client.FindIterable; @@ -72,6 +67,14 @@ import software.amazon.awssdk.services.glue.GlueClient; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.ArrayList; @@ -100,7 +103,7 @@ public class DocDBRecordHandlerTest private DocDBRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -171,7 +174,7 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); mockDatabase = mock(MongoDatabase.class); mockCollection = mock(MongoCollection.class); mockIterable = mock(FindIterable.class); @@ -179,31 +182,27 @@ public void setUp() when(mockClient.getDatabase(eq(DEFAULT_SCHEMA))).thenReturn(mockDatabase); when(mockDatabase.getCollection(eq(TEST_TABLE))).thenReturn(mockCollection); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); handler = new DocDBRecordHandler(amazonS3, mockSecretsManager, mockAthena, connectionFactory, com.google.common.collect.ImmutableMap.of()); diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java index 33b467d2b8..42c38478a0 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java @@ -35,7 +35,6 @@ import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils; import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.util.json.Jackson; import com.fasterxml.jackson.core.type.TypeReference; import com.google.common.annotations.VisibleForTesting; @@ -56,6 +55,7 @@ import software.amazon.awssdk.services.dynamodb.model.QueryResponse; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -128,7 +128,7 @@ public ThrottlingInvoker load(String tableName) } @VisibleForTesting - DynamoDBRecordHandler(DynamoDbClient ddbClient, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) + DynamoDBRecordHandler(DynamoDbClient ddbClient, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, sourceType, configOptions); this.ddbClient = ddbClient; diff --git a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java index 878542ad5c..9972e3fc0f 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java @@ -38,7 +38,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.dynamodb.util.DDBTypeUtils; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.complex.impl.UnionListReader; @@ -66,6 +65,7 @@ import software.amazon.awssdk.services.glue.model.EntityNotFoundException; import software.amazon.awssdk.services.glue.model.StorageDescriptor; import software.amazon.awssdk.services.glue.model.Table; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.LocalDate; @@ -136,7 +136,7 @@ public void setup() logger.info("{}: enter", testName.getMethodName()); allocator = new BlockAllocatorImpl(); - handler = new DynamoDBRecordHandler(ddbClient, mock(AmazonS3.class), mock(SecretsManagerClient.class), mock(AthenaClient.class), "source_type", com.google.common.collect.ImmutableMap.of()); + handler = new DynamoDBRecordHandler(ddbClient, mock(S3Client.class), mock(SecretsManagerClient.class), mock(AthenaClient.class), "source_type", com.google.common.collect.ImmutableMap.of()); metadataHandler = new DynamoDBMetadataHandler(new LocalKeyFactory(), secretsManager, athena, "spillBucket", "spillPrefix", ddbClient, glueClient, com.google.common.collect.ImmutableMap.of()); } diff --git a/athena-elasticsearch/pom.xml b/athena-elasticsearch/pom.xml index 7e9748b077..3ef358cfcc 100644 --- a/athena-elasticsearch/pom.xml +++ b/athena-elasticsearch/pom.xml @@ -62,33 +62,6 @@ ${log4j2Version} runtime - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - software.amazon.jsii jsii-runtime diff --git a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java index facfe16209..1d90956ad1 100644 --- a/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java +++ b/athena-elasticsearch/src/main/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandler.java @@ -27,8 +27,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.elasticsearch.qpt.ElasticsearchQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.elasticsearch.action.search.ClearScrollRequest; @@ -45,6 +43,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -89,7 +88,7 @@ public class ElasticsearchRecordHandler public ElasticsearchRecordHandler(Map configOptions) { - super(AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), + super(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), SOURCE_TYPE, configOptions); this.typeUtils = new ElasticsearchTypeUtils(); @@ -100,7 +99,7 @@ public ElasticsearchRecordHandler(Map configOptions) @VisibleForTesting protected ElasticsearchRecordHandler( - AmazonS3 amazonS3, + S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, AwsRestHighLevelClientFactory clientFactory, diff --git a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java index 99c641d32f..1336badd71 100644 --- a/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java +++ b/athena-elasticsearch/src/test/java/com/amazonaws/athena/connectors/elasticsearch/ElasticsearchRecordHandlerTest.java @@ -37,11 +37,6 @@ import com.amazonaws.athena.connector.lambda.records.RecordResponse; import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -68,6 +63,14 @@ import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -116,7 +119,7 @@ public class ElasticsearchRecordHandlerTest private SearchResponse mockScrollResponse; @Mock - private AmazonS3 amazonS3; + private S3Client amazonS3; @Mock private SecretsManagerClient awsSecretsManager; @@ -124,12 +127,6 @@ public class ElasticsearchRecordHandlerTest @Mock private AthenaClient athena; - @Mock - PutObjectResult putObjectResult; - - @Mock - S3Object s3Object; - String[] expectedDocuments = {"[mytext : My favorite Sci-Fi movie is Interstellar.], [mykeyword : I love keywords.], [mylong : {11,12,13}], [myinteger : 666115], [myshort : 1972], [mybyte : 5], [mydouble : 47.5], [myscaled : 7], [myfloat : 5.6], [myhalf : 6.2], [mydatemilli : 2020-05-15T06:49:30], [mydatenano : {2020-05-15T06:50:01.457}], [myboolean : true], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 357345987],[l1date : 2020-05-15T06:57:44.123],[l1nested : {[l2short : {1,2,3,4,5,6,7,8,9,10}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {}]" ,"[mytext : My favorite TV comedy is Seinfeld.], [mykeyword : I hate key-values.], [mylong : {14,null,16}], [myinteger : 732765666], [myshort : 1971], [mybyte : 7], [mydouble : 27.6], [myscaled : 10], [myfloat : 7.8], [myhalf : 7.3], [mydatemilli : null], [mydatenano : {2020-05-15T06:49:30.001}], [myboolean : false], [mybinary : U29tZSBiaW5hcnkgYmxvYg==], [mynested : {[l1long : 7322775555],[l1date : 2020-05-15T01:57:44.777],[l1nested : {[l2short : {11,12,13,14,15,16,null,18,19,20}],[l2binary : U29tZSBiaW5hcnkgYmxvYg==]}]}], [objlistouter : {{[objlistinner : {{[title : somebook],[hi : hi]}}],[test2 : title]}}]"}; @@ -276,31 +273,27 @@ public void setUp() allocator = new BlockAllocatorImpl(); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); spillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java index 87f3064d92..402420e0bb 100644 --- a/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java +++ b/athena-example/src/main/java/com/amazonaws/athena/connectors/example/ExampleRecordHandler.java @@ -34,17 +34,17 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintProjector; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.S3Object; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableIntHolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; - import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.NoSuchKeyException; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.BufferedReader; @@ -77,15 +77,15 @@ public class ExampleRecordHandler */ private static final String SOURCE_TYPE = "example"; - private AmazonS3 amazonS3; + private S3Client amazonS3; public ExampleRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); } @VisibleForTesting - protected ExampleRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) + protected ExampleRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; @@ -230,10 +230,13 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor private BufferedReader openS3File(String bucket, String key) { logger.info("openS3File: opening file " + bucket + ":" + key); - if (amazonS3.doesObjectExist(bucket, key)) { - S3Object obj = amazonS3.getObject(bucket, key); + try { + ResponseInputStream responseStream = amazonS3.getObject(GetObjectRequest.builder().bucket(bucket).key(key).build()); logger.info("openS3File: opened file " + bucket + ":" + key); - return new BufferedReader(new InputStreamReader(obj.getObjectContent())); + return new BufferedReader(new InputStreamReader(responseStream)); + } + catch (NoSuchKeyException e) { + logger.error("openS3File: failed to open file " + bucket + ":" + key, e); } return null; } diff --git a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java index c69dc6b3df..de2b30524b 100644 --- a/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java +++ b/athena-example/src/test/java/com/amazonaws/athena/connectors/example/ExampleRecordHandlerTest.java @@ -33,9 +33,6 @@ import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse; import com.amazonaws.athena.connector.lambda.records.RecordResponse; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -48,6 +45,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -59,6 +60,7 @@ import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -72,7 +74,7 @@ public class ExampleRecordHandlerTest System.getenv("publishing").equalsIgnoreCase("true"); private BlockAllocatorImpl allocator; private Schema schemaForRead; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient awsSecretsManager; private AthenaClient athena; private S3BlockSpillReader spillReader; @@ -105,23 +107,18 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); awsSecretsManager = mock(SecretsManagerClient.class); athena = mock(AthenaClient.class); - when(amazonS3.doesObjectExist(nullable(String.class), nullable(String.class))).thenReturn(true); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - S3Object mockObject = mock(S3Object.class); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(getFakeObject()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(getFakeObject())); } }); diff --git a/athena-federation-sdk/pom.xml b/athena-federation-sdk/pom.xml index 43fb9f3388..c532d70ef0 100644 --- a/athena-federation-sdk/pom.xml +++ b/athena-federation-sdk/pom.xml @@ -80,6 +80,66 @@ + + software.amazon.awssdk + apache-client + ${aws-sdk-v2.version} + + + software.amazon.awssdk + athena + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + + + + software.amazon.awssdk + glue + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + + + + software.amazon.awssdk + kms + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + + + + software.amazon.awssdk + lambda + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + + + + software.amazon.awssdk + s3 + ${aws-sdk-v2.version} + + + software.amazon.awssdk + netty-nio-client + + + software.amazon.awssdk secretsmanager @@ -121,29 +181,6 @@ com.fasterxml.jackson.core jackson-annotations - - - - software.amazon.awssdk - glue - ${aws-sdk-v2.version} - - - software.amazon.awssdk - netty-nio-client - - - - - software.amazon.awssdk - apache-client - ${aws-sdk-v2.version} - - - software.amazon.awssdk - athena - ${aws-sdk-v2.version} - software.amazon.awssdk netty-nio-client @@ -192,40 +229,6 @@ aws-lambda-java-core 1.2.3 - - software.amazon.awssdk - lambda - ${aws-sdk-v2.version} - - - software.amazon.awssdk - netty-nio-client - - - - - com.amazonaws - aws-java-sdk-s3 - ${aws-sdk.version} - - - - com.amazonaws - aws-java-sdk-kms - - - - - software.amazon.awssdk - kms - ${aws-sdk-v2.version} - - - software.amazon.awssdk - netty-nio-client - - - com.google.guava guava diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java index dfac5b00d6..268dad7ddb 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/BlockUtils.java @@ -77,6 +77,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneId; @@ -273,6 +274,9 @@ else if (value instanceof LocalDateTime) { pos, ((LocalDateTime) value).atZone(UTC_ZONE_ID).toInstant().toEpochMilli()); } + else if (value instanceof Instant) { + ((DateMilliVector) vector).setSafe(pos, ((Instant) value).toEpochMilli()); + } else { ((DateMilliVector) vector).setSafe(pos, (long) value); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java index 48806b99dc..6415484b41 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillReader.java @@ -25,12 +25,14 @@ import com.amazonaws.athena.connector.lambda.security.BlockCrypto; import com.amazonaws.athena.connector.lambda.security.EncryptionKey; import com.amazonaws.athena.connector.lambda.security.NoOpBlockCrypto; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.S3Object; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import java.io.IOException; @@ -40,10 +42,10 @@ public class S3BlockSpillReader { private static final Logger logger = LoggerFactory.getLogger(S3BlockSpillReader.class); - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final BlockAllocator allocator; - public S3BlockSpillReader(AmazonS3 amazonS3, BlockAllocator allocator) + public S3BlockSpillReader(S3Client amazonS3, BlockAllocator allocator) { this.amazonS3 = requireNonNull(amazonS3, "amazonS3 was null"); this.allocator = requireNonNull(allocator, "allocator was null"); @@ -59,13 +61,16 @@ public S3BlockSpillReader(AmazonS3 amazonS3, BlockAllocator allocator) */ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schema) { - S3Object fullObject = null; + ResponseInputStream responseStream = null; try { logger.debug("read: Started reading block from S3"); - fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("read: Completed reading block from S3"); BlockCrypto blockCrypto = (key != null) ? new AesGcmBlockCrypto(allocator) : new NoOpBlockCrypto(allocator); - Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent()), schema); + Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream), schema); logger.debug("read: Completed decrypting block of size."); return block; } @@ -73,12 +78,12 @@ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schem throw new RuntimeException(ex); } finally { - if (fullObject != null) { + if (responseStream != null) { try { - fullObject.close(); + responseStream.close(); } catch (IOException ex) { - logger.warn("read: Exception while closing S3 object", ex); + logger.warn("read: Exception while closing S3 response stream", ex); } } } @@ -93,24 +98,27 @@ public Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema schem */ public byte[] read(S3SpillLocation spillLocation, EncryptionKey key) { - S3Object fullObject = null; + ResponseInputStream responseStream = null; try { logger.debug("read: Started reading block from S3"); - fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("read: Completed reading block from S3"); BlockCrypto blockCrypto = (key != null) ? new AesGcmBlockCrypto(allocator) : new NoOpBlockCrypto(allocator); - return blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent())); + return blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream)); } catch (IOException ex) { throw new RuntimeException(ex); } finally { - if (fullObject != null) { + if (responseStream != null) { try { - fullObject.close(); + responseStream.close(); } catch (IOException ex) { - logger.warn("read: Exception while closing S3 object", ex); + logger.warn("read: Exception while closing S3 response stream", ex); } } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java index 2604b5e228..de879feafd 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.java @@ -27,10 +27,6 @@ import com.amazonaws.athena.connector.lambda.security.BlockCrypto; import com.amazonaws.athena.connector.lambda.security.EncryptionKey; import com.amazonaws.athena.connector.lambda.security.NoOpBlockCrypto; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ObjectMetadata; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.S3Object; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.ByteStreams; @@ -38,10 +34,16 @@ import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; @@ -76,7 +78,7 @@ public class S3BlockSpiller private static final String SPILL_PUT_REQUEST_HEADERS_ENV = "spill_put_request_headers"; //Used to write to S3 - private final AmazonS3 amazonS3; + private final S3Client amazonS3; //Used to optionally encrypt Blocks. private final BlockCrypto blockCrypto; //Used to create new blocks. @@ -125,7 +127,7 @@ public class S3BlockSpiller * @param constraintEvaluator The ConstraintEvaluator that should be used to constrain writes. */ public S3BlockSpiller( - AmazonS3 amazonS3, + S3Client amazonS3, SpillConfig spillConfig, BlockAllocator allocator, Schema schema, @@ -146,7 +148,7 @@ public S3BlockSpiller( * @param maxRowsPerCall The max number of rows to allow callers to write in one call. */ public S3BlockSpiller( - AmazonS3 amazonS3, + S3Client amazonS3, SpillConfig spillConfig, BlockAllocator allocator, Schema schema, @@ -318,29 +320,24 @@ public void close() /** * Grabs the request headers from env and sets them on the request */ - private void setRequestHeadersFromEnv(PutObjectRequest request) + private Map getRequestHeadersFromEnv() { String headersFromEnvStr = configOptions.get(SPILL_PUT_REQUEST_HEADERS_ENV); if (headersFromEnvStr == null || headersFromEnvStr.isEmpty()) { - return; + return Collections.emptyMap(); } try { ObjectMapper mapper = new ObjectMapper(); TypeReference> typeRef = new TypeReference>() {}; Map headers = mapper.readValue(headersFromEnvStr, typeRef); - for (Map.Entry entry : headers.entrySet()) { - String oldValue = request.putCustomRequestHeader(entry.getKey(), entry.getValue()); - if (oldValue != null) { - logger.warn("Key: %s has been overwritten with: %s. Old value: %s", - entry.getKey(), entry.getValue(), oldValue); - } - } + return headers; } catch (com.fasterxml.jackson.core.JsonProcessingException e) { String message = String.format("Invalid value for environment variable: %s : %s", SPILL_PUT_REQUEST_HEADERS_ENV, headersFromEnvStr); logger.error(message, e); } + return Collections.emptyMap(); } /** @@ -361,15 +358,13 @@ protected SpillLocation write(Block block) // Set the contentLength otherwise the s3 client will buffer again since it // only sees the InputStream wrapper. - ObjectMetadata objMeta = new ObjectMetadata(); - objMeta.setContentLength(bytes.length); - PutObjectRequest request = new PutObjectRequest( - spillLocation.getBucket(), - spillLocation.getKey(), - new ByteArrayInputStream(bytes), - objMeta); - setRequestHeadersFromEnv(request); - amazonS3.putObject(request); + PutObjectRequest request = PutObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .contentLength((long) bytes.length) + .metadata(getRequestHeadersFromEnv()) + .build(); + amazonS3.putObject(request, RequestBody.fromBytes(bytes)); logger.info("write: Completed spilling block of size {} bytes", bytes.length); return spillLocation; @@ -393,9 +388,12 @@ protected Block read(S3SpillLocation spillLocation, EncryptionKey key, Schema sc { try { logger.debug("write: Started reading block from S3"); - S3Object fullObject = amazonS3.getObject(spillLocation.getBucket(), spillLocation.getKey()); + ResponseInputStream responseStream = amazonS3.getObject(GetObjectRequest.builder() + .bucket(spillLocation.getBucket()) + .key(spillLocation.getKey()) + .build()); logger.debug("write: Completed reading block from S3"); - Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(fullObject.getObjectContent()), schema); + Block block = blockCrypto.decrypt(key, ByteStreams.toByteArray(responseStream), schema); logger.debug("write: Completed decrypting block of size."); return block; } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java index f15040812a..3dbb3e62ef 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifier.java @@ -20,11 +20,11 @@ * #L% */ -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.AmazonS3Exception; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.S3Exception; import java.util.Set; import java.util.stream.Collectors; @@ -39,14 +39,14 @@ public class SpillLocationVerifier private enum BucketState {UNCHECKED, VALID, INVALID} - private final AmazonS3 amazons3; + private final S3Client amazons3; private String bucket; private BucketState state; /** * @param amazons3 The S3 object for the account. */ - public SpillLocationVerifier(AmazonS3 amazons3) + public SpillLocationVerifier(S3Client amazons3) { this.amazons3 = amazons3; this.bucket = null; @@ -85,7 +85,7 @@ public void checkBucketAuthZ(String spillBucket) void updateBucketState() { try { - Set buckets = amazons3.listBuckets().stream().map(b -> b.getName()).collect(Collectors.toSet()); + Set buckets = amazons3.listBuckets().buckets().stream().map(b -> b.name()).collect(Collectors.toSet()); if (!buckets.contains(bucket)) { state = BucketState.INVALID; @@ -96,7 +96,7 @@ void updateBucketState() logger.info("The state of bucket {} has been updated to {} from {}", bucket, state, BucketState.UNCHECKED); } - catch (AmazonS3Exception ex) { + catch (S3Exception ex) { throw new RuntimeException("Error while checking bucket ownership for " + bucket, ex); } } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java index 88c00b461f..0810ba64b1 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java @@ -60,7 +60,6 @@ import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; @@ -69,6 +68,7 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -151,7 +151,7 @@ public MetadataHandler(String sourceType, java.util.Map configOp this.secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); this.athena = AthenaClient.create(); - this.verifier = new SpillLocationVerifier(AmazonS3ClientBuilder.standard().build()); + this.verifier = new SpillLocationVerifier(S3Client.create()); this.athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, configOptions).build(); } @@ -174,7 +174,7 @@ public MetadataHandler( this.sourceType = sourceType; this.spillBucket = spillBucket; this.spillPrefix = spillPrefix; - this.verifier = new SpillLocationVerifier(AmazonS3ClientBuilder.standard().build()); + this.verifier = new SpillLocationVerifier(S3Client.create()); this.athenaInvoker = ThrottlingInvoker.newDefaultBuilder(ATHENA_EXCEPTION_FILTER, configOptions).build(); } diff --git a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java index 1ac7a85645..ac3e563005 100644 --- a/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java +++ b/athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/RecordHandler.java @@ -42,12 +42,11 @@ import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory; import com.amazonaws.services.lambda.runtime.Context; import com.amazonaws.services.lambda.runtime.RequestStreamHandler; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -69,7 +68,7 @@ public abstract class RecordHandler private static final String MAX_BLOCK_SIZE_BYTES = "MAX_BLOCK_SIZE_BYTES"; private static final int NUM_SPILL_THREADS = 2; protected final java.util.Map configOptions; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final String sourceType; private final CachableSecretsManager secretsManager; private final AthenaClient athena; @@ -81,7 +80,7 @@ public abstract class RecordHandler public RecordHandler(String sourceType, java.util.Map configOptions) { this.sourceType = sourceType; - this.amazonS3 = AmazonS3ClientBuilder.defaultClient(); + this.amazonS3 = S3Client.create(); this.secretsManager = new CachableSecretsManager(SecretsManagerClient.create()); this.athena = AthenaClient.create(); this.configOptions = configOptions; @@ -91,7 +90,7 @@ public RecordHandler(String sourceType, java.util.Map configOpti /** * @param sourceType Used to aid in logging diagnostic info when raising a support case. */ - public RecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) + public RecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, String sourceType, java.util.Map configOptions) { this.sourceType = sourceType; this.amazonS3 = amazonS3; diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java index 0c9de56318..0abc45c3ec 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/data/S3BlockSpillerTest.java @@ -25,11 +25,6 @@ import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Schema; @@ -44,6 +39,13 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -51,8 +53,6 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.nullable; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -71,7 +71,7 @@ public class S3BlockSpillerTest private String splitId = "splitId"; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; private S3BlockSpiller blockWriter; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -130,18 +130,20 @@ public void spillTest() final ByteHolder byteHolder = new ByteHolder(); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor requestArgument = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor bodyArgument = ArgumentCaptor.forClass(RequestBody.class); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + PutObjectResponse response = PutObjectResponse.builder().build(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); - return mock(PutObjectResult.class); + return response; } }); @@ -151,9 +153,9 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(bucket, ((S3SpillLocation) blockLocation).getBucket()); assertEquals(prefix + "/" + requestId + "/" + splitId + ".0", ((S3SpillLocation) blockLocation).getKey()); } - verify(mockS3, times(1)).putObject(argument.capture()); - assertEquals(argument.getValue().getBucketName(), bucket); - assertEquals(argument.getValue().getKey(), prefix + "/" + requestId + "/" + splitId + ".0"); + verify(mockS3, times(1)).putObject(requestArgument.capture(), bodyArgument.capture()); + assertEquals(requestArgument.getValue().bucket(), bucket); + assertEquals(requestArgument.getValue().key(), prefix + "/" + requestId + "/" + splitId + ".0"); SpillLocation blockLocation2 = blockWriter.write(expected); @@ -162,25 +164,23 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(prefix + "/" + requestId + "/" + splitId + ".1", ((S3SpillLocation) blockLocation2).getKey()); } - verify(mockS3, times(2)).putObject(argument.capture()); - assertEquals(argument.getValue().getBucketName(), bucket); - assertEquals(argument.getValue().getKey(), prefix + "/" + requestId + "/" + splitId + ".1"); + verify(mockS3, times(2)).putObject(requestArgument.capture(), bodyArgument.capture()); + assertEquals(requestArgument.getValue().bucket(), bucket); + assertEquals(requestArgument.getValue().key(), prefix + "/" + requestId + "/" + splitId + ".1"); verifyNoMoreInteractions(mockS3); reset(mockS3); logger.info("spillTest: Starting read test."); - when(mockS3.getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1"))) + when(mockS3.getObject(any(GetObjectRequest.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - S3Object mockObject = mock(S3Object.class); - when(mockObject.getObjectContent()).thenReturn(new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); } }); @@ -189,7 +189,7 @@ public Object answer(InvocationOnMock invocationOnMock) assertEquals(expected, block); verify(mockS3, times(1)) - .getObject(eq(bucket), eq(prefix + "/" + requestId + "/" + splitId + ".1")); + .getObject(any(GetObjectRequest.class)); verifyNoMoreInteractions(mockS3); diff --git a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java index 9c44a7a84d..88ca65cd5b 100644 --- a/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java +++ b/athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/domain/spill/SpillLocationVerifierTest.java @@ -20,8 +20,6 @@ * #L% */ -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.Bucket; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -29,6 +27,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.Bucket; +import software.amazon.awssdk.services.s3.model.ListBucketsResponse; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -57,7 +59,7 @@ public void setup() bucketNames = Arrays.asList("bucket1", "bucket2", "bucket3"); List buckets = createBuckets(bucketNames); - AmazonS3 mockS3 = createMockS3(buckets); + S3Client mockS3 = createMockS3(buckets); spyVerifier = spy(new SpillLocationVerifier(mockS3)); logger.info("setUpBefore - exit"); @@ -137,19 +139,19 @@ public void checkBucketAuthZFail() logger.info("checkBucketAuthZFail - exit"); } - private AmazonS3 createMockS3(List buckets) + private S3Client createMockS3(List buckets) { - AmazonS3 s3mock = mock(AmazonS3.class); - when(s3mock.listBuckets()).thenReturn(buckets); + S3Client s3mock = mock(S3Client.class); + ListBucketsResponse response = ListBucketsResponse.builder().buckets(buckets).build(); + when(s3mock.listBuckets()).thenReturn(response); return s3mock; } private List createBuckets(List names) { - List buckets = new ArrayList(); + List buckets = new ArrayList<>(); for (String name : names) { - Bucket bucket = mock(Bucket.class); - when(bucket.getName()).thenReturn(name); + Bucket bucket = Bucket.builder().name(name).build(); buckets.add(bucket); } diff --git a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java index a9307fda32..4c86fe7e13 100644 --- a/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java +++ b/athena-gcs/src/main/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandler.java @@ -28,8 +28,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.dataset.file.FileFormat; @@ -50,6 +48,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.nio.charset.StandardCharsets; @@ -75,7 +74,7 @@ public class GcsRecordHandler public GcsRecordHandler(BufferAllocator allocator, java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); this.allocator = allocator; @@ -89,7 +88,7 @@ public GcsRecordHandler(BufferAllocator allocator, java.util.Map * @param amazonAthena An instance of AmazonAthena */ @VisibleForTesting - protected GcsRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) + protected GcsRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java index 5a6d3e0fc8..bc3d54c7dc 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsCompositeHandlerTest.java @@ -19,8 +19,6 @@ */ package com.amazonaws.athena.connectors.gcs; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; import org.junit.jupiter.api.AfterAll; @@ -28,6 +26,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.mockito.Mockito; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -59,10 +58,8 @@ public void init() { mockedServiceAccountCredentials.when(() -> ServiceAccountCredentials.fromStream(Mockito.any())).thenReturn(serviceAccountCredentials); credentials = Mockito.mock(GoogleCredentials.class); mockedGoogleCredentials.when(() -> GoogleCredentials.fromStream(Mockito.any())).thenReturn(credentials); - AmazonS3ClientBuilder mockedAmazonS3Builder = Mockito.mock(AmazonS3ClientBuilder.class); - AmazonS3 mockedAmazonS3 = Mockito.mock(AmazonS3.class); - when(mockedAmazonS3Builder.build()).thenReturn(mockedAmazonS3); - mockedS3Builder.when(AmazonS3ClientBuilder::standard).thenReturn(mockedAmazonS3Builder); + S3Client mockedAmazonS3 = Mockito.mock(S3Client.class); + when(S3Client.create()).thenReturn(mockedAmazonS3); } @AfterAll diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java index 6f88aa4e3c..3e340142b7 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GcsRecordHandlerTest.java @@ -34,8 +34,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.auth.oauth2.GoogleCredentials; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -49,6 +47,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.File; @@ -105,7 +104,7 @@ public void initCommonMockedStatic() LOGGER.info("Starting init."); federatedIdentity = Mockito.mock(FederatedIdentity.class); BlockAllocator allocator = new BlockAllocatorImpl(); - AmazonS3 amazonS3 = mock(AmazonS3.class); + S3Client amazonS3 = mock(S3Client.class); // Create Spill config // This will be enough for a single block @@ -123,7 +122,7 @@ public void initCommonMockedStatic() .withSpillLocation(s3SpillLocation) .build(); // To mock AmazonS3 via AmazonS3ClientBuilder - mockedS3Builder.when(AmazonS3ClientBuilder::defaultClient).thenReturn(amazonS3); + mockedS3Builder.when(S3Client::create).thenReturn(amazonS3); // To mock SecretsManagerClient via SecretsManagerClient mockedSecretManagerBuilder.when(SecretsManagerClient::create).thenReturn(secretsManager); // To mock AmazonAthena via AmazonAthenaClientBuilder diff --git a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java index 885e5ca4d8..7d6fbef4f4 100644 --- a/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java +++ b/athena-gcs/src/test/java/com/amazonaws/athena/connectors/gcs/GenericGcsTest.java @@ -19,19 +19,19 @@ */ package com.amazonaws.athena.connectors.gcs; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; import org.mockito.MockedStatic; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.lang.reflect.Field; public class GenericGcsTest { - protected MockedStatic mockedS3Builder; + protected MockedStatic mockedS3Builder; protected MockedStatic mockedSecretManagerBuilder; protected MockedStatic mockedAthenaClientBuilder; protected MockedStatic mockedGoogleCredentials; @@ -41,7 +41,7 @@ public class GenericGcsTest protected void initCommonMockedStatic() { - mockedS3Builder = Mockito.mockStatic(AmazonS3ClientBuilder.class); + mockedS3Builder = Mockito.mockStatic(S3Client.class); mockedSecretManagerBuilder = Mockito.mockStatic(SecretsManagerClient.class); mockedAthenaClientBuilder = Mockito.mockStatic(AthenaClient.class); mockedGoogleCredentials = Mockito.mockStatic(GoogleCredentials.class); diff --git a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java index a0d82f7284..dc0bf9dce6 100644 --- a/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java +++ b/athena-google-bigquery/src/main/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandler.java @@ -29,8 +29,6 @@ import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.athena.connectors.google.bigquery.qpt.BigQueryQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.api.gax.rpc.ServerStream; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryException; @@ -59,6 +57,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -91,13 +90,13 @@ public class BigQueryRecordHandler BigQueryRecordHandler(java.util.Map configOptions, BufferAllocator allocator) { - this(AmazonS3ClientBuilder.defaultClient(), + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions, allocator); } @VisibleForTesting - public BigQueryRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions, BufferAllocator allocator) + public BigQueryRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions, BufferAllocator allocator) { super(amazonS3, secretsManager, athena, BigQueryConstants.SOURCE_TYPE, configOptions); this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build(); diff --git a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java index 948ee7d677..371b939508 100644 --- a/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java +++ b/athena-google-bigquery/src/test/java/com/amazonaws/athena/connectors/google/bigquery/BigQueryRecordHandlerTest.java @@ -35,7 +35,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.cloud.bigquery.BigQuery; @@ -78,6 +77,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.nio.charset.StandardCharsets; @@ -120,7 +120,7 @@ public class BigQueryRecordHandlerTest @Mock private ArrowSchema arrowSchema; private BigQueryRecordHandler bigQueryRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpiller spillWriter; private S3BlockSpillReader spillReader; private Schema schemaForRead; @@ -200,7 +200,7 @@ public void init() mockedStatic.when(() -> BigQueryUtils.getBigQueryClient(any(Map.class))).thenReturn(bigQuery); federatedIdentity = Mockito.mock(FederatedIdentity.class); allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); //Create Spill config spillConfig = SpillConfig.newBuilder() diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java index f8e3282ebe..56f0ddcfa8 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseKerberosUtils.java @@ -19,17 +19,15 @@ */ package com.amazonaws.athena.connectors.hbase; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.GetObjectRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import java.io.BufferedInputStream; import java.io.File; @@ -68,20 +66,24 @@ public static Path copyConfigFilesFromS3ToTempFolder(java.util.Map responseStream = s3Client.getObject(GetObjectRequest.builder() + .bucket(s3Bucket[0]) + .key(s3Object.key()) + .build()); + InputStream inputStream = new BufferedInputStream(responseStream); + String key = s3Object.key(); String fName = key.substring(key.indexOf('/') + 1); if (!fName.isEmpty()) { File file = new File(tempDir + File.separator + fName); diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java index 8af36b5e54..ad51a3fc35 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java @@ -31,8 +31,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.qpt.HbaseQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -52,6 +50,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -81,7 +80,7 @@ public class HbaseRecordHandler //Used to denote the 'type' of this connector for diagnostic purposes. private static final String SOURCE_TYPE = "hbase"; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final HbaseConnectionFactory connectionFactory; private final HbaseQueryPassthrough queryPassthrough = new HbaseQueryPassthrough(); @@ -89,7 +88,7 @@ public class HbaseRecordHandler public HbaseRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new HbaseConnectionFactory(), @@ -97,7 +96,7 @@ public HbaseRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected HbaseRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, HbaseConnectionFactory connectionFactory, java.util.Map configOptions) + protected HbaseRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, HbaseConnectionFactory connectionFactory, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.amazonS3 = amazonS3; diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java index 8d3ebd1b45..017608c74d 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandlerTest.java @@ -43,11 +43,6 @@ import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import com.amazonaws.athena.connectors.hbase.connection.HbaseConnectionFactory; import com.amazonaws.athena.connectors.hbase.connection.ResultProcessor; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; @@ -69,6 +64,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -102,7 +104,7 @@ public class HbaseRecordHandlerTest private HbaseRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -135,33 +137,29 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); schemaForRead = TestUtils.makeSchema().addStringField(HbaseSchemaUtils.ROW_COLUMN_NAME).build(); diff --git a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java index 9709e4adde..aa676b7b99 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public HiveMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - HiveMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + HiveMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java index 5d5ab5035a..1450634a1e 100644 --- a/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java +++ b/athena-hortonworks-hive/src/main/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandler.java @@ -28,12 +28,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -60,11 +59,11 @@ public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java } public HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new HiveQueryStringBuilder(HIVE_QUOTE_CHARACTER, new HiveFederationExpressionParser(HIVE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + HiveRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java index 39074a4f36..32dba90175 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveMuxRecordHandlerTest.java @@ -29,7 +29,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.BeforeClass; @@ -37,6 +36,7 @@ import org.mockito.Mockito; import org.testng.Assert; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -50,7 +50,7 @@ public class HiveMuxRecordHandlerTest private Map recordHandlerMap; private HiveRecordHandler hiveRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -64,7 +64,7 @@ public void setup() { this.hiveRecordHandler = Mockito.mock(HiveRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("recordHive", this.hiveRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java index 5fc3438e2f..c45cfaf7c4 100644 --- a/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java +++ b/athena-hortonworks-hive/src/test/java/com/amazonaws/athena/connectors/hortonworks/HiveRecordHandlerTest.java @@ -46,10 +46,10 @@ import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -64,7 +64,7 @@ public class HiveRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -72,7 +72,7 @@ public class HiveRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); Mockito.when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); diff --git a/athena-jdbc/pom.xml b/athena-jdbc/pom.xml index 8ea9371ea6..a5dadefca8 100644 --- a/athena-jdbc/pom.xml +++ b/athena-jdbc/pom.xml @@ -9,33 +9,6 @@ athena-jdbc 2022.47.1 - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - software.amazon.jsii jsii-runtime diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java index 2b791d3454..e2cb4f227c 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandler.java @@ -30,11 +30,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -61,7 +61,7 @@ public MultiplexingJdbcRecordHandler(JdbcRecordHandlerFactory jdbcRecordHandlerF @VisibleForTesting protected MultiplexingJdbcRecordHandler( - AmazonS3 amazonS3, + S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, diff --git a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java index 044488e256..9b82c5428e 100644 --- a/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java +++ b/athena-jdbc/src/main/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandler.java @@ -54,7 +54,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.connection.RdsSecretsCredentialProvider; import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.holders.NullableBigIntHolder; @@ -75,6 +74,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Array; @@ -111,7 +111,7 @@ protected JdbcRecordHandler(String sourceType, java.util.Map con } protected JdbcRecordHandler( - AmazonS3 amazonS3, + S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, DatabaseConnectionConfig databaseConnectionConfig, diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java index 1177f10375..60e229d6f7 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/MultiplexingJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class MultiplexingJdbcRecordHandlerTest private Map recordHandlerMap; private JdbcRecordHandler fakeJdbcRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.fakeJdbcRecordHandler = Mockito.mock(JdbcRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("fakedatabase", this.fakeJdbcRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java index 17a0cb55e3..cfbeba5602 100644 --- a/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java +++ b/athena-jdbc/src/test/java/com/amazonaws/athena/connectors/jdbc/manager/JdbcRecordHandlerTest.java @@ -39,17 +39,18 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; import org.apache.arrow.vector.holders.NullableFloat8Holder; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; -import org.mockito.stubbing.Answer; +import org.mockito.invocation.InvocationOnMock; +import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -75,7 +76,7 @@ public class JdbcRecordHandlerTest private JdbcRecordHandler jdbcRecordHandler; private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -89,7 +90,7 @@ public void setup() this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class); Mockito.when(this.jdbcConnectionFactory.getConnection(nullable(JdbcCredentialProvider.class))).thenReturn(this.connection); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); @@ -143,15 +144,16 @@ public void readWithConstraint() BlockSpiller s3Spiller = new S3BlockSpiller(this.amazonS3, spillConfig, allocator, fieldSchema, constraintEvaluator, com.google.common.collect.ImmutableMap.of()); ReadRecordsRequest readRecordsRequest = new ReadRecordsRequest(this.federatedIdentity, "testCatalog", "testQueryId", inputTableName, fieldSchema, splitBuilder.build(), constraints, 1024, 1024); - Mockito.when(amazonS3.putObject(any())).thenAnswer((Answer) invocation -> { - ByteArrayInputStream byteArrayInputStream = (ByteArrayInputStream) ((PutObjectRequest) invocation.getArguments()[0]).getInputStream(); - int n = byteArrayInputStream.available(); - byte[] bytes = new byte[n]; - byteArrayInputStream.read(bytes, 0, n); - String data = new String(bytes, StandardCharsets.UTF_8); - Assert.assertTrue(data.contains("testVal1") || data.contains("testVal2") || data.contains("testPartitionValue")); - return new PutObjectResult(); - }); + Mockito.when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteArrayInputStream inputStream = (ByteArrayInputStream) ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + int n = inputStream.available(); + byte[] bytes = new byte[n]; + inputStream.read(bytes, 0, n); + String data = new String(bytes, StandardCharsets.UTF_8); + Assert.assertTrue(data.contains("testVal1") || data.contains("testVal2") || data.contains("testPartitionValue")); + return PutObjectResponse.builder().build(); + }); this.jdbcRecordHandler.readWithConstraint(s3Spiller, readRecordsRequest, queryStatusChecker); } diff --git a/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandler.java b/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandler.java index 74d8679563..dccd950901 100644 --- a/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandler.java +++ b/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandler.java @@ -27,8 +27,6 @@ import com.amazonaws.athena.connectors.kafka.dto.KafkaField; import com.amazonaws.athena.connectors.kafka.dto.SplitParameters; import com.amazonaws.athena.connectors.kafka.dto.TopicResultSet; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; @@ -42,6 +40,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.Duration; @@ -60,14 +59,14 @@ public class KafkaRecordHandler KafkaRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); } @VisibleForTesting - public KafkaRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) + public KafkaRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, KafkaConstants.KAFKA_SOURCE, configOptions); } diff --git a/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaUtils.java b/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaUtils.java index 7ff8296da6..0746aace06 100644 --- a/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaUtils.java +++ b/athena-kafka/src/main/java/com/amazonaws/athena/connectors/kafka/KafkaUtils.java @@ -24,15 +24,6 @@ import com.amazonaws.athena.connectors.kafka.dto.TopicResultSet; import com.amazonaws.athena.connectors.kafka.serde.KafkaCsvDeserializer; import com.amazonaws.athena.connectors.kafka.serde.KafkaJsonDeserializer; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.GetObjectRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.DynamicMessage; @@ -47,6 +38,13 @@ import org.apache.kafka.common.serialization.StringDeserializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -326,20 +324,24 @@ protected static Path copyCertificatesFromS3ToTempFolder(java.util.Map responseStream = s3Client.getObject(GetObjectRequest.builder() + .bucket(s3Bucket[0]) + .key(objectSummary.key()) + .build()); + InputStream inputStream = new BufferedInputStream(responseStream); + String key = objectSummary.key(); String fName = key.substring(key.indexOf('/') + 1); if (!fName.isEmpty()) { File file = new File(tempDir + File.separator + fName); diff --git a/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandlerTest.java b/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandlerTest.java index 929f8ff14b..d172b52f63 100644 --- a/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandlerTest.java +++ b/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.kafka.dto.*; -import com.amazonaws.services.s3.AmazonS3; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.protobuf.Descriptors; @@ -62,6 +61,7 @@ import software.amazon.awssdk.services.glue.model.GetSchemaResponse; import software.amazon.awssdk.services.glue.model.GetSchemaVersionRequest; import software.amazon.awssdk.services.glue.model.GetSchemaVersionResponse; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; @@ -86,7 +86,7 @@ public class KafkaRecordHandlerTest { GlueClient awsGlue; @Mock - AmazonS3 amazonS3; + S3Client amazonS3; @Mock SecretsManagerClient awsSecretsManager; diff --git a/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaUtilsTest.java b/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaUtilsTest.java index 178a303023..7baacf6180 100644 --- a/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaUtilsTest.java +++ b/athena-kafka/src/test/java/com/amazonaws/athena/connectors/kafka/KafkaUtilsTest.java @@ -24,12 +24,6 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3Client; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ListObjectsRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.Types; @@ -48,6 +42,13 @@ import org.mockito.junit.MockitoJUnitRunner; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -89,13 +90,7 @@ public class KafkaUtilsTest { BasicAWSCredentials credentials; @Mock - AmazonS3Client amazonS3Client; - - @Mock - AmazonS3ClientBuilder clientBuilder; - - @Mock - ObjectListing oList; + S3Client amazonS3Client; final java.util.Map configOptions = com.google.common.collect.ImmutableMap.of( @@ -105,9 +100,8 @@ public class KafkaUtilsTest { "certificates_s3_reference", "s3://kafka-connector-test-bucket/kafkafiles/", "secrets_manager_secret", "Kafka_afq"); - private MockedConstruction mockedObjectMapper; private MockedConstruction mockedDefaultCredentials; - private MockedStatic mockedS3ClientBuilder; + private MockedStatic mockedS3ClientBuilder; private MockedStatic mockedSecretsManagerClient; @@ -132,30 +126,22 @@ public void init() throws Exception { Mockito.when(secretValueResponse.secretString()).thenReturn(creds); Mockito.when(awsSecretsManager.getSecretValue(Mockito.isA(GetSecretValueRequest.class))).thenReturn(secretValueResponse); - mockedObjectMapper = Mockito.mockConstruction(ObjectMapper.class, - (mock, context) -> { - Mockito.doReturn(map).when(mock).readValue(Mockito.eq(creds), nullable(TypeReference.class)); - }); mockedDefaultCredentials = Mockito.mockConstruction(DefaultAWSCredentialsProviderChain.class, (mock, context) -> { Mockito.when(mock.getCredentials()).thenReturn(credentials); }); - mockedS3ClientBuilder = Mockito.mockStatic(AmazonS3ClientBuilder.class); - mockedS3ClientBuilder.when(()-> AmazonS3ClientBuilder.standard()).thenReturn(clientBuilder); - - Mockito.doReturn(clientBuilder).when(clientBuilder).withCredentials(any()); - Mockito.when(clientBuilder.build()).thenReturn(amazonS3Client); - Mockito.when(amazonS3Client.listObjects(any(), any())).thenReturn(oList); - S3Object s3Obj = new S3Object(); - s3Obj.setObjectContent(new ByteArrayInputStream("largeContentFile".getBytes())); - Mockito.when(amazonS3Client.getObject(any())).thenReturn(s3Obj); - S3ObjectSummary s3 = new S3ObjectSummary(); - s3.setKey("test/key"); - Mockito.when(oList.getObjectSummaries()).thenReturn(com.google.common.collect.ImmutableList.of(s3)); + mockedS3ClientBuilder = Mockito.mockStatic(S3Client.class); + mockedS3ClientBuilder.when(()-> S3Client.create()).thenReturn(amazonS3Client); + + S3Object s3 = S3Object.builder().key("test/key").build(); + Mockito.when(amazonS3Client.listObjects(any(ListObjectsRequest.class))).thenReturn(ListObjectsResponse.builder() + .contents(s3) + .build()); + Mockito.when(amazonS3Client.getObject(any(GetObjectRequest.class))) + .thenReturn(new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream("largeContentFile".getBytes()))); } @After public void tearDown() { - mockedObjectMapper.close(); mockedDefaultCredentials.close(); mockedS3ClientBuilder.close(); mockedSecretsManagerClient.close(); diff --git a/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandler.java b/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandler.java index 080d72ea44..eed652e94f 100644 --- a/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandler.java +++ b/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandler.java @@ -27,8 +27,6 @@ import com.amazonaws.athena.connectors.msk.dto.MSKField; import com.amazonaws.athena.connectors.msk.dto.SplitParameters; import com.amazonaws.athena.connectors.msk.dto.TopicResultSet; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRecord; @@ -38,6 +36,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.Duration; @@ -53,14 +52,14 @@ public class AmazonMskRecordHandler AmazonMskRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); } @VisibleForTesting - public AmazonMskRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) + public AmazonMskRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, AmazonMskConstants.MSK_SOURCE, configOptions); } diff --git a/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskUtils.java b/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskUtils.java index 9a12c3bf06..9a9db0bf5c 100644 --- a/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskUtils.java +++ b/athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskUtils.java @@ -24,15 +24,6 @@ import com.amazonaws.athena.connectors.msk.dto.TopicResultSet; import com.amazonaws.athena.connectors.msk.serde.MskCsvDeserializer; import com.amazonaws.athena.connectors.msk.serde.MskJsonDeserializer; -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.GetObjectRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.Types; @@ -45,6 +36,13 @@ import org.apache.kafka.common.serialization.StringDeserializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -305,20 +303,24 @@ protected static Path copyCertificatesFromS3ToTempFolder(java.util.Map responseStream = s3Client.getObject(GetObjectRequest.builder() + .bucket(s3Bucket[0]) + .key(objectSummary.key()) + .build()); + InputStream inputStream = new BufferedInputStream(responseStream); + String key = objectSummary.key(); String fName = key.substring(key.indexOf('/') + 1); if (!fName.isEmpty()) { File file = new File(tempDir + File.separator + fName); diff --git a/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandlerTest.java b/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandlerTest.java index 7fcfc6a607..e1d2cf31a1 100644 --- a/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandlerTest.java +++ b/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; import com.amazonaws.athena.connectors.msk.dto.*; -import com.amazonaws.services.s3.AmazonS3; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.pojo.Field; @@ -52,6 +51,7 @@ import org.mockito.MockitoAnnotations; import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Collections; @@ -68,7 +68,7 @@ public class AmazonMskRecordHandlerTest { private static final ObjectMapper objectMapper = new ObjectMapper(); @Mock - AmazonS3 amazonS3; + S3Client amazonS3; @Mock SecretsManagerClient awsSecretsManager; diff --git a/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskUtilsTest.java b/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskUtilsTest.java index 1888bddbe1..36db23e1cc 100644 --- a/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskUtilsTest.java +++ b/athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskUtilsTest.java @@ -24,12 +24,6 @@ import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.s3.AmazonS3Client; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectSummary; -import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -46,6 +40,13 @@ import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -58,7 +59,6 @@ import static org.junit.Assert.*; import static java.util.Arrays.asList; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.nullable; @RunWith(MockitoJUnitRunner.class) public class AmazonMskUtilsTest { @@ -87,13 +87,7 @@ public class AmazonMskUtilsTest { BasicAWSCredentials credentials; @Mock - AmazonS3Client amazonS3Client; - - @Mock - AmazonS3ClientBuilder clientBuilder; - - @Mock - ObjectListing oList; + S3Client amazonS3Client; final java.util.Map configOptions = com.google.common.collect.ImmutableMap.of( "glue_registry_arn", "arn:aws:glue:us-west-2:123456789101:registry/Athena-Kafka", @@ -101,9 +95,8 @@ public class AmazonMskUtilsTest { "kafka_endpoint", "12.207.18.179:9092", "certificates_s3_reference", "s3://kafka-connector-test-bucket/kafkafiles/", "secrets_manager_secret", "Kafka_afq"); - private MockedConstruction mockedObjectMapper; private MockedConstruction mockedDefaultCredentials; - private MockedStatic mockedS3ClientBuilder; + private MockedStatic mockedS3ClientBuilder; private MockedStatic mockedSecretsManagerClient; @Before @@ -113,6 +106,8 @@ public void init() throws Exception { System.setProperty("aws.secretKey", "vamsajdsjkl"); mockedSecretsManagerClient = Mockito.mockStatic(SecretsManagerClient.class); mockedSecretsManagerClient.when(()-> SecretsManagerClient.create()).thenReturn(awsSecretsManager); + mockedS3ClientBuilder = Mockito.mockStatic(S3Client.class); + mockedS3ClientBuilder.when(()-> S3Client.create()).thenReturn(amazonS3Client); String creds = "{\"username\":\"admin\",\"password\":\"test\",\"keystore_password\":\"keypass\",\"truststore_password\":\"trustpass\",\"ssl_key_password\":\"sslpass\"}"; @@ -126,30 +121,20 @@ public void init() throws Exception { Mockito.when(secretValueResponse.secretString()).thenReturn(creds); Mockito.when(awsSecretsManager.getSecretValue(Mockito.isA(GetSecretValueRequest.class))).thenReturn(secretValueResponse); - mockedObjectMapper = Mockito.mockConstruction(ObjectMapper.class, - (mock, context) -> { - Mockito.doReturn(map).when(mock).readValue(Mockito.eq(creds), nullable(TypeReference.class)); - }); mockedDefaultCredentials = Mockito.mockConstruction(DefaultAWSCredentialsProviderChain.class, (mock, context) -> { Mockito.when(mock.getCredentials()).thenReturn(credentials); }); - mockedS3ClientBuilder = Mockito.mockStatic(AmazonS3ClientBuilder.class); - mockedS3ClientBuilder.when(()-> AmazonS3ClientBuilder.standard()).thenReturn(clientBuilder); - Mockito.doReturn(clientBuilder).when(clientBuilder).withCredentials(any()); - Mockito.when(clientBuilder.build()).thenReturn(amazonS3Client); - Mockito.when(amazonS3Client.listObjects(any(), any())).thenReturn(oList); - S3Object s3Obj = new S3Object(); - s3Obj.setObjectContent(new ByteArrayInputStream("largeContentFile".getBytes())); - Mockito.when(amazonS3Client.getObject(any())).thenReturn(s3Obj); - S3ObjectSummary s3 = new S3ObjectSummary(); - s3.setKey("test/key"); - Mockito.when(oList.getObjectSummaries()).thenReturn(com.google.common.collect.ImmutableList.of(s3)); + S3Object s3 = S3Object.builder().key("test/key").build(); + Mockito.when(amazonS3Client.listObjects(any(ListObjectsRequest.class))).thenReturn(ListObjectsResponse.builder() + .contents(s3) + .build()); + Mockito.when(amazonS3Client.getObject(any(GetObjectRequest.class))) + .thenReturn(new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream("largeContentFile".getBytes()))); } @After public void tearDown() { - mockedObjectMapper.close(); mockedDefaultCredentials.close(); mockedS3ClientBuilder.close(); mockedSecretsManagerClient.close(); diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java index 659262750e..a159921eed 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public MySqlMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - MySqlMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + MySqlMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java index a7cb397c97..8acf177306 100644 --- a/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java +++ b/athena-mysql/src/main/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandler.java @@ -29,14 +29,13 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -77,12 +76,12 @@ public MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, jav public MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new MySqlQueryStringBuilder(MYSQL_QUOTE_CHARACTER, new MySqlFederationExpressionParser(MYSQL_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + MySqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java index fdb6f56ec7..ea9c543c0b 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class MySqlMuxJdbcRecordHandlerTest private Map recordHandlerMap; private MySqlRecordHandler mySqlRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.mySqlRecordHandler = Mockito.mock(MySqlRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("mysql", this.mySqlRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java index e4a40ff0b5..157c08228b 100644 --- a/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java +++ b/athena-mysql/src/test/java/com/amazonaws/athena/connectors/mysql/MySqlRecordHandlerTest.java @@ -36,7 +36,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -47,6 +46,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -66,7 +66,7 @@ public class MySqlRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -74,7 +74,7 @@ public class MySqlRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java index a61278292a..2456b3aa27 100644 --- a/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java +++ b/athena-neptune/src/main/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandler.java @@ -26,12 +26,11 @@ import com.amazonaws.athena.connectors.neptune.Enums.GraphType; import com.amazonaws.athena.connectors.neptune.propertygraph.PropertyGraphHandler; import com.amazonaws.athena.connectors.neptune.rdf.RDFHandler; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; /** @@ -63,7 +62,7 @@ public class NeptuneRecordHandler extends RecordHandler public NeptuneRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), NeptuneConnection.createConnection(configOptions), @@ -72,7 +71,7 @@ public NeptuneRecordHandler(java.util.Map configOptions) @VisibleForTesting protected NeptuneRecordHandler( - AmazonS3 amazonS3, + S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, NeptuneConnection neptuneConnection, diff --git a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java index 13ff521aaf..bde646b1d3 100644 --- a/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java +++ b/athena-neptune/src/test/java/com/amazonaws/athena/connectors/neptune/NeptuneRecordHandlerTest.java @@ -46,11 +46,6 @@ import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.io.ByteStreams; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -76,6 +71,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -97,7 +99,7 @@ public class NeptuneRecordHandlerTest extends TestBase { private Schema schemaPGVertexForRead; private Schema schemaPGEdgeForRead; private Schema schemaPGQueryForRead; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient awsSecretsManager; private AthenaClient athena; private S3BlockSpillReader spillReader; @@ -164,34 +166,32 @@ public void setUp() { .build(); allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); awsSecretsManager = mock(SecretsManagerClient.class); athena = mock(AthenaClient.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))).thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); - ByteHolder byteHolder; - synchronized (mockS3Storage) { - byteHolder = mockS3Storage.get(0); - mockS3Storage.remove(0); - logger.info("getObject: total size " + mockS3Storage.size()); - } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; - }); + when(amazonS3.getObject(any(GetObjectRequest.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteHolder byteHolder; + synchronized (mockS3Storage) { + byteHolder = mockS3Storage.get(0); + mockS3Storage.remove(0); + logger.info("getObject: total size " + mockS3Storage.size()); + } + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); + }); handler = new NeptuneRecordHandler(amazonS3, awsSecretsManager, athena, neptuneConnection, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(amazonS3, allocator); diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java index 0b6d2f00b2..9a1f0a09ae 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public OracleMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - OracleMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + OracleMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java index 9d4f94bc30..c312b87f5b 100644 --- a/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java +++ b/athena-oracle/src/main/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandler.java @@ -28,14 +28,13 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -75,12 +74,12 @@ public OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, ja public OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new OracleQueryStringBuilder(ORACLE_QUOTE_CHARACTER, new OracleFederationExpressionParser(ORACLE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + OracleRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java index 67907d9b31..537cc0c969 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcMetadataHandlerTest.java @@ -32,8 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcMetadataHandler; -import com.amazonaws.athena.connectors.oracle.OracleMetadataHandler; -import com.amazonaws.athena.connectors.oracle.OracleMuxMetadataHandler; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java index 485c1f0ba6..1ec10050b3 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleMuxJdbcRecordHandlerTest.java @@ -30,12 +30,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.oracle.OracleMuxRecordHandler; import com.amazonaws.athena.connectors.oracle.OracleRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -48,7 +48,7 @@ public class OracleMuxJdbcRecordHandlerTest private Map recordHandlerMap; private OracleRecordHandler oracleRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -59,7 +59,7 @@ public void setup() { this.oracleRecordHandler = Mockito.mock(OracleRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("oracle", this.oracleRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java index 4d0a887602..2e2f026297 100644 --- a/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java +++ b/athena-oracle/src/test/java/com/amazonaws/athena/connectors/oracle/OracleRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -42,6 +41,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,7 +58,7 @@ public class OracleRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -69,7 +69,7 @@ public class OracleRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java index b6adea0a75..8b98b0813f 100644 --- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java +++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -57,7 +57,7 @@ public PostGreSqlMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - PostGreSqlMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + PostGreSqlMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java index d054ff3d06..0c89828a66 100644 --- a/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java +++ b/athena-postgresql/src/main/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandler.java @@ -29,14 +29,13 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -69,12 +68,12 @@ public PostGreSqlRecordHandler(java.util.Map configOptions) public PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(POSTGRESQL_DRIVER_CLASS, POSTGRESQL_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - protected PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, + protected PostGreSqlRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java index dd7a6a7736..eadd042db0 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class PostGreSqlMuxJdbcRecordHandlerTest private Map recordHandlerMap; private PostGreSqlRecordHandler postGreSqlRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.postGreSqlRecordHandler = Mockito.mock(PostGreSqlRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("postgres", this.postGreSqlRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java index 96b879d41e..123093f31f 100644 --- a/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java +++ b/athena-postgresql/src/test/java/com/amazonaws/athena/connectors/postgresql/PostGreSqlRecordHandlerTest.java @@ -33,7 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -46,6 +45,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.math.BigDecimal; @@ -68,7 +68,7 @@ public class PostGreSqlRecordHandlerTest extends TestBase private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -76,7 +76,7 @@ public class PostGreSqlRecordHandlerTest extends TestBase public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-redis/pom.xml b/athena-redis/pom.xml index b4faced6d0..fa13974351 100644 --- a/athena-redis/pom.xml +++ b/athena-redis/pom.xml @@ -9,33 +9,6 @@ athena-redis 2022.47.1 - - com.amazonaws - aws-java-sdk-core - ${aws-sdk.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-databind - - - com.fasterxml.jackson.core - jackson-annotations - - - software.amazon.jsii jsii-runtime diff --git a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java index 837e2decb1..5981aface2 100644 --- a/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java +++ b/athena-redis/src/main/java/com/amazonaws/athena/connectors/redis/RedisRecordHandler.java @@ -29,8 +29,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionFactory; import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.qpt.RedisQueryPassthrough; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import io.lettuce.core.KeyScanCursor; import io.lettuce.core.ScanArgs; import io.lettuce.core.ScanCursor; @@ -42,6 +40,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.HashMap; @@ -86,14 +85,14 @@ public class RedisRecordHandler private static final int SCAN_COUNT_SIZE = 100; private final RedisConnectionFactory redisConnectionFactory; - private final AmazonS3 amazonS3; + private final S3Client amazonS3; private final RedisQueryPassthrough queryPassthrough = new RedisQueryPassthrough(); public RedisRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.standard().build(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new RedisConnectionFactory(), @@ -101,7 +100,7 @@ public RedisRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected RedisRecordHandler(AmazonS3 amazonS3, + protected RedisRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, RedisConnectionFactory redisConnectionFactory, diff --git a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java index b0c774c423..d330af3ca3 100644 --- a/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java +++ b/athena-redis/src/test/java/com/amazonaws/athena/connectors/redis/RedisRecordHandlerTest.java @@ -40,11 +40,6 @@ import com.amazonaws.athena.connectors.redis.lettuce.RedisConnectionWrapper; import com.amazonaws.athena.connectors.redis.util.MockKeyScanCursor; import com.amazonaws.athena.connectors.redis.util.MockScoredValueScanCursor; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import io.lettuce.core.ScanArgs; @@ -66,6 +61,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -105,7 +107,7 @@ public class RedisRecordHandlerTest private RedisRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -137,33 +139,29 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - Mockito.lenient().when(amazonS3.putObject(any())) + Mockito.lenient().when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - Mockito.lenient().when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + Mockito.lenient().when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); when(mockSecretsManager.getSecretValue(nullable(GetSecretValueRequest.class))) diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java index 38b1b26a24..2fe7b8fa3e 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -58,7 +58,7 @@ public RedshiftMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - RedshiftMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + RedshiftMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java index cdb9db0954..8684595097 100644 --- a/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java +++ b/athena-redshift/src/main/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandler.java @@ -30,12 +30,11 @@ import com.amazonaws.athena.connectors.postgresql.PostGreSqlQueryStringBuilder; import com.amazonaws.athena.connectors.postgresql.PostGreSqlRecordHandler; import com.amazonaws.athena.connectors.postgresql.PostgreSqlFederationExpressionParser; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import org.apache.arrow.util.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import static com.amazonaws.athena.connectors.postgresql.PostGreSqlConstants.POSTGRES_QUOTE_CHARACTER; @@ -60,12 +59,12 @@ public RedshiftRecordHandler(java.util.Map configOptions) public RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - super(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + super(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new GenericJdbcConnectionFactory(databaseConnectionConfig, PostGreSqlMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(REDSHIFT_DRIVER_CLASS, REDSHIFT_DEFAULT_PORT)), new PostGreSqlQueryStringBuilder(POSTGRES_QUOTE_CHARACTER, new PostgreSqlFederationExpressionParser(POSTGRES_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + RedshiftRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(databaseConnectionConfig, amazonS3, secretsManager, athena, jdbcConnectionFactory, jdbcSplitQueryBuilder, configOptions); } diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java index 4e0ff391cf..2a4abf05bc 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class RedshiftMuxJdbcRecordHandlerTest private Map recordHandlerMap; private RedshiftRecordHandler redshiftRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.redshiftRecordHandler = Mockito.mock(RedshiftRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("redshift", this.redshiftRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java index d9e2508b22..c9242f17a8 100644 --- a/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java +++ b/athena-redshift/src/test/java/com/amazonaws/athena/connectors/redshift/RedshiftRecordHandlerTest.java @@ -36,7 +36,6 @@ import com.amazonaws.athena.connectors.postgresql.PostGreSqlMetadataHandler; import com.amazonaws.athena.connectors.postgresql.PostGreSqlQueryStringBuilder; import com.amazonaws.athena.connectors.postgresql.PostgreSqlFederationExpressionParser; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -49,6 +48,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.math.BigDecimal; @@ -71,7 +71,7 @@ public class RedshiftRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -79,7 +79,7 @@ public class RedshiftRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java index a61d340fc7..2414854794 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public SaphanaMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - SaphanaMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + SaphanaMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java index 750bdf2434..67f7a93e6a 100644 --- a/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java +++ b/athena-saphana/src/main/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandler.java @@ -31,8 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,6 +38,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -66,7 +65,7 @@ public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, j SaphanaConstants.SAPHANA_DEFAULT_PORT)), configOptions); } @VisibleForTesting - SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) + SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.jdbcSplitQueryBuilder = Validate.notNull(jdbcSplitQueryBuilder, "query builder must not be null"); @@ -74,7 +73,7 @@ public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, j public SaphanaRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new SaphanaQueryStringBuilder(SAPHANA_QUOTE_CHARACTER, new SaphanaFederationExpressionParser(SAPHANA_QUOTE_CHARACTER)), configOptions); } diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java index 4acddbca38..5f17964b44 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java +++ b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class SaphanaMuxJdbcRecordHandlerTest private Map recordHandlerMap; private SaphanaRecordHandler saphanaRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.saphanaRecordHandler = Mockito.mock(SaphanaRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("saphana", this.saphanaRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java index 80934a3f8d..c48ced9e6c 100644 --- a/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java +++ b/athena-saphana/src/test/java/com/amazonaws/athena/connectors/saphana/SaphanaRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -42,6 +41,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,7 +58,7 @@ public class SaphanaRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -66,7 +66,7 @@ public class SaphanaRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java index 3874591f69..2fb0812375 100644 --- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public SnowflakeMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - SnowflakeMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + SnowflakeMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java index 57acce2f23..28ac13ff21 100644 --- a/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java +++ b/athena-snowflake/src/main/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandler.java @@ -30,12 +30,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -66,11 +65,11 @@ public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, } public SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, GenericJdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + SnowflakeRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java index 3acd762712..6a219a3b1f 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMetadataHandlerTest.java @@ -29,16 +29,16 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; -import software.amazon.awssdk.services.athena.AthenaClient; -import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; -import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; +import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; import java.sql.*; import java.util.*; diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java index 2e7c0b70cb..367fde0afc 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeMuxJdbcRecordHandlerTest.java @@ -30,12 +30,12 @@ import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.snowflake.SnowflakeMuxRecordHandler; import com.amazonaws.athena.connectors.snowflake.SnowflakeRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -50,7 +50,7 @@ public class SnowflakeMuxJdbcRecordHandlerTest private Map recordHandlerMap; private SnowflakeRecordHandler snowflakeRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -61,7 +61,7 @@ public void setup() { this.snowflakeRecordHandler = Mockito.mock(SnowflakeRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("snowflake", this.snowflakeRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java index e7f4813d34..56531dcdac 100644 --- a/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java +++ b/athena-snowflake/src/test/java/com/amazonaws/athena/connectors/snowflake/SnowflakeRecordHandlerTest.java @@ -33,7 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -43,6 +42,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -59,7 +59,7 @@ public class SnowflakeRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -67,7 +67,7 @@ public class SnowflakeRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java index 7282b19ac4..e9d5009639 100644 --- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java +++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public SqlServerMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - SqlServerMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + SqlServerMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java index 6bdd298a57..073f5ad946 100644 --- a/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java +++ b/athena-sqlserver/src/main/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandler.java @@ -29,12 +29,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -61,12 +60,12 @@ public SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, public SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new SqlServerQueryStringBuilder(SQLSERVER_QUOTE_CHARACTER, new SqlServerFederationExpressionParser(SQLSERVER_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + SqlServerRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java index e5074306b3..e6faa255d9 100644 --- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java +++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerMuxRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class SqlServerMuxRecordHandlerTest private Map recordHandlerMap; private SqlServerRecordHandler sqlServerRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.sqlServerRecordHandler = Mockito.mock(SqlServerRecordHandler.class); this.recordHandlerMap = Collections.singletonMap(SqlServerConstants.NAME, this.sqlServerRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java index 58cc7a8dc6..c6f8f659dd 100644 --- a/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java +++ b/athena-sqlserver/src/test/java/com/amazonaws/athena/connectors/sqlserver/SqlServerRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,6 +40,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -57,7 +57,7 @@ public class SqlServerRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -66,7 +66,7 @@ public void setup() throws Exception { System.setProperty("aws.region", "us-east-1"); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java index 38e47cbd02..6fabc20bf7 100644 --- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java +++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandler.java @@ -24,9 +24,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -54,7 +54,7 @@ public SynapseMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - SynapseMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + SynapseMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java index 11e5d74148..a7a6aed815 100644 --- a/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java +++ b/athena-synapse/src/main/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandler.java @@ -33,8 +33,6 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -43,6 +41,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -64,14 +63,14 @@ public SynapseRecordHandler(java.util.Map configOptions) } public SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), new SynapseJdbcConnectionFactory(databaseConnectionConfig, SynapseMetadataHandler.JDBC_PROPERTIES, new DatabaseConnectionInfo(SynapseConstants.DRIVER_CLASS, SynapseConstants.DEFAULT_PORT)), new SynapseQueryStringBuilder(QUOTE_CHARACTER, new SynapseFederationExpressionParser(QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + SynapseRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java index 4138897089..7e90ae436f 100644 --- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java +++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMetadataHandlerTest.java @@ -50,7 +50,6 @@ import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; - import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java index 369d2d7dd2..3ed375cae4 100644 --- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java +++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseMuxRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class SynapseMuxRecordHandlerTest private Map recordHandlerMap; private SynapseRecordHandler synapseRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.synapseRecordHandler = Mockito.mock(SynapseRecordHandler.class); this.recordHandlerMap = Collections.singletonMap(SynapseConstants.NAME, this.synapseRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java index aaa61dea9a..b0108974cc 100644 --- a/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java +++ b/athena-synapse/src/test/java/com/amazonaws/athena/connectors/synapse/SynapseRecordHandlerTest.java @@ -31,7 +31,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Schema; @@ -40,6 +39,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,7 +58,7 @@ public class SynapseRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -66,7 +66,7 @@ public class SynapseRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java index 2f1a9f2954..5667ddae62 100644 --- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java +++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataMuxRecordHandler.java @@ -25,9 +25,9 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandlerFactory; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.annotations.VisibleForTesting; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.util.Map; @@ -55,7 +55,7 @@ public TeradataMuxRecordHandler(java.util.Map configOptions) } @VisibleForTesting - TeradataMuxRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, + TeradataMuxRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, DatabaseConnectionConfig databaseConnectionConfig, Map recordHandlerMap, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, jdbcConnectionFactory, databaseConnectionConfig, recordHandlerMap, configOptions); diff --git a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java index 74b83ae3fd..52382322a6 100644 --- a/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java +++ b/athena-teradata/src/main/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandler.java @@ -29,12 +29,11 @@ import com.amazonaws.athena.connectors.jdbc.manager.JDBCUtil; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.google.common.annotations.VisibleForTesting; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.Validate; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,12 +57,12 @@ public TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, public TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, java.util.Map configOptions) { - this(databaseConnectionConfig, AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), + this(databaseConnectionConfig, S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), jdbcConnectionFactory, new TeradataQueryStringBuilder(TERADATA_QUOTE_CHARACTER, new TeradataFederationExpressionParser(TERADATA_QUOTE_CHARACTER)), configOptions); } @VisibleForTesting - TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final AmazonS3 amazonS3, final SecretsManagerClient secretsManager, + TeradataRecordHandler(DatabaseConnectionConfig databaseConnectionConfig, final S3Client amazonS3, final SecretsManagerClient secretsManager, final AthenaClient athena, JdbcConnectionFactory jdbcConnectionFactory, JdbcSplitQueryBuilder jdbcSplitQueryBuilder, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, databaseConnectionConfig, jdbcConnectionFactory, configOptions); diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java index 8ee13facf2..0c768ba3db 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataMuxJdbcRecordHandlerTest.java @@ -28,12 +28,12 @@ import com.amazonaws.athena.connectors.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.manager.JdbcRecordHandler; -import com.amazonaws.services.s3.AmazonS3; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -46,7 +46,7 @@ public class TeradataMuxJdbcRecordHandlerTest private Map recordHandlerMap; private TeradataRecordHandler teradataRecordHandler; private JdbcRecordHandler jdbcRecordHandler; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; private QueryStatusChecker queryStatusChecker; @@ -57,7 +57,7 @@ public void setup() { this.teradataRecordHandler = Mockito.mock(TeradataRecordHandler.class); this.recordHandlerMap = Collections.singletonMap("teradata", this.teradataRecordHandler); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); diff --git a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java index 09dd2f4c75..4a306592df 100644 --- a/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java +++ b/athena-teradata/src/test/java/com/amazonaws/athena/connectors/teradata/TeradataRecordHandlerTest.java @@ -32,7 +32,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.jdbc.manager.JdbcSplitQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.vector.types.Types; @@ -42,6 +41,7 @@ import org.junit.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.sql.Connection; @@ -58,7 +58,7 @@ public class TeradataRecordHandlerTest private Connection connection; private JdbcConnectionFactory jdbcConnectionFactory; private JdbcSplitQueryBuilder jdbcSplitQueryBuilder; - private AmazonS3 amazonS3; + private S3Client amazonS3; private SecretsManagerClient secretsManager; private AthenaClient athena; @@ -66,7 +66,7 @@ public class TeradataRecordHandlerTest public void setup() throws Exception { - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); this.secretsManager = Mockito.mock(SecretsManagerClient.class); this.athena = Mockito.mock(AthenaClient.class); this.connection = Mockito.mock(Connection.class); diff --git a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java index 9975a8c33f..a8cc2be021 100644 --- a/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java +++ b/athena-timestream/src/main/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandler.java @@ -40,8 +40,6 @@ import com.amazonaws.athena.connectors.timestream.qpt.TimestreamQueryPassthrough; import com.amazonaws.athena.connectors.timestream.query.QueryFactory; import com.amazonaws.athena.connectors.timestream.query.SelectQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; import com.amazonaws.services.timestreamquery.model.Datum; import com.amazonaws.services.timestreamquery.model.QueryRequest; @@ -59,6 +57,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.time.Instant; @@ -92,7 +91,7 @@ public class TimestreamRecordHandler public TimestreamRecordHandler(java.util.Map configOptions) { this( - AmazonS3ClientBuilder.defaultClient(), + S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), TimestreamClientBuilder.buildQueryClient(SOURCE_TYPE), @@ -100,7 +99,7 @@ public TimestreamRecordHandler(java.util.Map configOptions) } @VisibleForTesting - protected TimestreamRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AmazonTimestreamQuery tsQuery, java.util.Map configOptions) + protected TimestreamRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, AmazonTimestreamQuery tsQuery, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); this.tsQuery = tsQuery; diff --git a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java index 9804555422..d7ad28e816 100644 --- a/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java +++ b/athena-timestream/src/test/java/com/amazonaws/athena/connectors/timestream/TimestreamRecordHandlerTest.java @@ -40,11 +40,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.amazonaws.services.timestreamquery.AmazonTimestreamQuery; import com.amazonaws.services.timestreamquery.model.QueryRequest; import com.amazonaws.services.timestreamquery.model.QueryResult; @@ -64,7 +59,14 @@ import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -100,7 +102,7 @@ public class TimestreamRecordHandlerTest private TimestreamRecordHandler handler; private BlockAllocator allocator; private List mockS3Storage = new ArrayList<>(); - private AmazonS3 amazonS3; + private S3Client amazonS3; private S3BlockSpillReader spillReader; private Schema schemaForRead; private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); @@ -144,31 +146,29 @@ public void setUp() allocator = new BlockAllocatorImpl(); - amazonS3 = mock(AmazonS3.class); + amazonS3 = mock(S3Client.class); - when(amazonS3.putObject(any())) + when(amazonS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); ByteHolder byteHolder = new ByteHolder(); byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { mockS3Storage.add(byteHolder); + logger.info("puObject: total size " + mockS3Storage.size()); } - return mock(PutObjectResult.class); + return PutObjectResponse.builder().build(); }); - when(amazonS3.getObject(nullable(String.class), nullable(String.class))) + when(amazonS3.getObject(any(GetObjectRequest.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { - S3Object mockObject = mock(S3Object.class); ByteHolder byteHolder; synchronized (mockS3Storage) { byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); + logger.info("getObject: total size " + mockS3Storage.size()); } - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); schemaForRead = SchemaBuilder.newBuilder() diff --git a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java index df8b16529a..33c7fce626 100644 --- a/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java +++ b/athena-tpcds/src/main/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandler.java @@ -26,8 +26,6 @@ import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.teradata.tpcds.Results; import com.teradata.tpcds.Session; import com.teradata.tpcds.Table; @@ -39,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.IOException; @@ -76,11 +75,11 @@ public class TPCDSRecordHandler public TPCDSRecordHandler(java.util.Map configOptions) { - super(AmazonS3ClientBuilder.defaultClient(), SecretsManagerClient.create(), AthenaClient.create(), SOURCE_TYPE, configOptions); + super(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), SOURCE_TYPE, configOptions); } @VisibleForTesting - protected TPCDSRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) + protected TPCDSRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient athena, java.util.Map configOptions) { super(amazonS3, secretsManager, athena, SOURCE_TYPE, configOptions); } diff --git a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java index e26bf2458d..a13b453c55 100644 --- a/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java +++ b/athena-tpcds/src/test/java/com/amazonaws/athena/connectors/tpcds/TPCDSRecordHandlerTest.java @@ -41,11 +41,6 @@ import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.PutObjectRequest; -import com.amazonaws.services.s3.model.PutObjectResult; -import com.amazonaws.services.s3.model.S3Object; -import com.amazonaws.services.s3.model.S3ObjectInputStream; import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteStreams; import com.teradata.tpcds.Table; @@ -61,7 +56,14 @@ import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import java.io.ByteArrayInputStream; @@ -98,7 +100,7 @@ public class TPCDSRecordHandlerTest private Schema schemaForRead; @Mock - private AmazonS3 mockS3; + private S3Client mockS3; @Mock private SecretsManagerClient mockSecretsManager; @@ -127,30 +129,28 @@ public void setUp() handler = new TPCDSRecordHandler(mockS3, mockSecretsManager, mockAthena, com.google.common.collect.ImmutableMap.of()); spillReader = new S3BlockSpillReader(mockS3, allocator); - when(mockS3.putObject(any())) + when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) .thenAnswer((InvocationOnMock invocationOnMock) -> { + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + ByteHolder byteHolder = new ByteHolder(); + byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); synchronized (mockS3Storage) { - InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream(); - ByteHolder byteHolder = new ByteHolder(); - byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); mockS3Storage.add(byteHolder); - return mock(PutObjectResult.class); + logger.info("puObject: total size " + mockS3Storage.size()); } + return PutObjectResponse.builder().build(); }); - when(mockS3.getObject(nullable(String.class), nullable(String.class))) - .thenAnswer((InvocationOnMock invocationOnMock) -> - { + when(mockS3.getObject(any(GetObjectRequest.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteHolder byteHolder; synchronized (mockS3Storage) { - S3Object mockObject = mock(S3Object.class); - ByteHolder byteHolder = mockS3Storage.get(0); + byteHolder = mockS3Storage.get(0); mockS3Storage.remove(0); - when(mockObject.getObjectContent()).thenReturn( - new S3ObjectInputStream( - new ByteArrayInputStream(byteHolder.getBytes()), null)); - return mockObject; + logger.info("getObject: total size " + mockS3Storage.size()); } + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); }); } diff --git a/athena-vertica/pom.xml b/athena-vertica/pom.xml index b5b086e449..745c4e5782 100644 --- a/athena-vertica/pom.xml +++ b/athena-vertica/pom.xml @@ -32,6 +32,11 @@ jcl-over-slf4j ${slf4j-log4j.version} + + org.apache.arrow + arrow-dataset + ${apache.arrow.version} + org.apache.logging.log4j log4j-slf4j2-impl diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java index dbcf85e8a5..72bdacf8c2 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaConstants.java @@ -24,6 +24,9 @@ public final class VerticaConstants public static final String VERTICA_NAME = "vertica"; public static final String VERTICA_DRIVER_CLASS = "com.vertica.jdbc.Driver"; public static final int VERTICA_DEFAULT_PORT = 5433; + public static final String VERTICA_SPLIT_QUERY_ID = "query_id"; + public static final String VERTICA_SPLIT_EXPORT_BUCKET = "exportBucket"; + public static final String VERTICA_SPLIT_OBJECT_KEY = "s3ObjectKey"; private VerticaConstants() {} } diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java index ee40632659..4e691900aa 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandler.java @@ -48,11 +48,6 @@ import com.amazonaws.athena.connectors.jdbc.qpt.JdbcQueryPassthrough; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.ListObjectsRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.google.common.collect.ImmutableMap; import org.apache.arrow.util.VisibleForTesting; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -62,6 +57,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; +import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -80,6 +80,9 @@ import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DEFAULT_PORT; import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_DRIVER_CLASS; import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_NAME; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID; import static com.amazonaws.athena.connectors.vertica.VerticaSchemaUtils.convertToArrowType; @@ -100,7 +103,7 @@ public class VerticaMetadataHandler private static final String[] TABLE_TYPES = {"TABLE"}; private final QueryFactory queryFactory = new QueryFactory(); private final VerticaSchemaUtils verticaSchemaUtils; - private AmazonS3 amazonS3; + private S3Client amazonS3; private final JdbcQueryPassthrough queryPassthrough = new JdbcQueryPassthrough(); @@ -117,11 +120,11 @@ public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions) { super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); - amazonS3 = AmazonS3ClientBuilder.defaultClient(); + amazonS3 = S3Client.create(); verticaSchemaUtils = new VerticaSchemaUtils(); } @VisibleForTesting - public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions, AmazonS3 amazonS3, VerticaSchemaUtils verticaSchemaUtils) + public VerticaMetadataHandler(DatabaseConnectionConfig databaseConnectionConfig, JdbcConnectionFactory jdbcConnectionFactory, Map configOptions, S3Client amazonS3, VerticaSchemaUtils verticaSchemaUtils) { super(databaseConnectionConfig, jdbcConnectionFactory, configOptions); this.amazonS3 = amazonS3; @@ -298,8 +301,8 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest request } logger.info("Vertica Export Statement: {}", preparedSQLStmt); - // Build the Set AWS Region SQL - String awsRegionSql = queryBuilder.buildSetAwsRegionSql(amazonS3.getRegion().toString()); + // Build the Set AWS Region SQL - Assumes using the default region provider chain + String awsRegionSql = queryBuilder.buildSetAwsRegionSql(DefaultAwsRegionProviderChain.builder().build().getRegion().toString()); // write the prepared SQL statement to the partition column created in enhancePartitionSchema blockWriter.writeRows((Block block, int rowNum) ->{ @@ -374,16 +377,16 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest * For each generated S3 object, create a split and add data to the split. */ Split split; - List s3ObjectSummaries = getlistExportedObjects(exportBucket, queryId); + List s3ObjectsList = getlistExportedObjects(exportBucket, queryId); - if(!s3ObjectSummaries.isEmpty()) + if(!s3ObjectsList.isEmpty()) { - for (S3ObjectSummary objectSummary : s3ObjectSummaries) + for (S3Object s3Object : s3ObjectsList) { split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) - .add("query_id", queryID) - .add("exportBucket", exportBucket) - .add("s3ObjectKey", objectSummary.getKey()) + .add(VERTICA_SPLIT_QUERY_ID, queryID) + .add(VERTICA_SPLIT_EXPORT_BUCKET, exportBucket) + .add(VERTICA_SPLIT_OBJECT_KEY, s3Object.key()) .build(); splits.add(split); @@ -395,9 +398,9 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest //No records were exported by Vertica for the issued query, creating a "empty" split logger.info("No records were exported by Vertica"); split = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) - .add("query_id", queryID) - .add("exportBucket", exportBucket) - .add("s3ObjectKey", EMPTY_STRING) + .add(VERTICA_SPLIT_QUERY_ID, queryID) + .add(VERTICA_SPLIT_EXPORT_BUCKET, exportBucket) + .add(VERTICA_SPLIT_OBJECT_KEY, EMPTY_STRING) .build(); splits.add(split); return new GetSplitsResponse(catalogName,split); @@ -428,17 +431,20 @@ private void executeQueriesOnVertica(Connection connection, String sqlStatement, /* * Get the list of all the exported S3 objects */ - private List getlistExportedObjects(String s3ExportBucket, String queryId){ - ObjectListing objectListing; + private List getlistExportedObjects(String s3ExportBucket, String queryId){ + ListObjectsResponse listObjectsResponse; try { - objectListing = amazonS3.listObjects(new ListObjectsRequest().withBucketName(s3ExportBucket).withPrefix(queryId)); + listObjectsResponse = amazonS3.listObjects(ListObjectsRequest.builder() + .bucket(s3ExportBucket) + .prefix(queryId) + .build()); } catch (SdkClientException e) { throw new RuntimeException("Exception listing the exported objects : " + e.getMessage(), e); } - return objectListing.getObjectSummaries(); + return listObjectsResponse.contents(); } private void testAccess(Connection conn, TableName table) { diff --git a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java index 795d02e402..29bec641d6 100644 --- a/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java +++ b/athena-vertica/src/main/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandler.java @@ -32,26 +32,34 @@ import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.model.*; -import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.arrow.dataset.file.FileFormat; +import org.apache.arrow.dataset.file.FileSystemDatasetFactory; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.dataset.source.DatasetFactory; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.VisibleForTesting; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.holders.*; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; -import java.io.BufferedReader; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID; + import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.HashMap; @@ -61,22 +69,18 @@ public class VerticaRecordHandler extends RecordHandler { private static final Logger logger = LoggerFactory.getLogger(VerticaRecordHandler.class); private static final String SOURCE_TYPE = "vertica"; - private static final String VERTICA_QUOTE_CHARACTER = "\""; - private static final String QUERY = "select * from S3Object s"; - private AmazonS3 amazonS3; public VerticaRecordHandler(java.util.Map configOptions) { - this(AmazonS3ClientBuilder.defaultClient(), + this(S3Client.create(), SecretsManagerClient.create(), AthenaClient.create(), configOptions); } @VisibleForTesting - protected VerticaRecordHandler(AmazonS3 amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) + protected VerticaRecordHandler(S3Client amazonS3, SecretsManagerClient secretsManager, AthenaClient amazonAthena, java.util.Map configOptions) { super(amazonS3, secretsManager, amazonAthena, SOURCE_TYPE, configOptions); - this.amazonS3 = amazonS3; } /** @@ -100,9 +104,9 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor Schema schemaName = recordsRequest.getSchema(); Split split = recordsRequest.getSplit(); - String id = split.getProperty("query_id"); - String exportBucket = split.getProperty("exportBucket"); - String s3ObjectKey = split.getProperty("s3ObjectKey"); + String id = split.getProperty(VERTICA_SPLIT_QUERY_ID); + String exportBucket = split.getProperty(VERTICA_SPLIT_EXPORT_BUCKET); + String s3ObjectKey = split.getProperty(VERTICA_SPLIT_OBJECT_KEY); if(!s3ObjectKey.isEmpty()) { //get column name and type from the Schema @@ -127,25 +131,25 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor } GeneratedRowWriter rowWriter = builder.build(); - /* - Using S3 Select to read the S3 Parquet file generated in the split - */ - //Creating the read Request - SelectObjectContentRequest request = generateBaseParquetRequest(exportBucket, s3ObjectKey); - try (SelectObjectContentResult result = amazonS3.selectObjectContent(request)) { - InputStream resultInputStream = result.getPayload().getRecordsInputStream(); - BufferedReader streamReader = new BufferedReader(new InputStreamReader(resultInputStream, StandardCharsets.UTF_8)); - String inputStr; - while ((inputStr = streamReader.readLine()) != null) { - HashMap map = new HashMap<>(); - //we are reading the parquet files, but serializing the output it as JSON as SDK provides a Parquet InputSerialization, but only a JSON or CSV OutputSerializatio - ObjectMapper objectMapper = new ObjectMapper(); - map = objectMapper.readValue(inputStr, HashMap.class); - rowContext.setNameValue(map); - - //Passing the RowContext to BlockWriter; - spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, rowContext) ? 1 : 0); + /* + Using Arrow Dataset to read the S3 Parquet file generated in the split + */ + try (ArrowReader reader = constructArrowReader(constructS3Uri(exportBucket, s3ObjectKey))) + { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + for (int row = 0; row < root.getRowCount(); row++) { + HashMap map = new HashMap<>(); + for (Field field : root.getSchema().getFields()) { + map.put(field.getName(), root.getVector(field).getObject(row)); + } + rowContext.setNameValue(map); + + //Passing the RowContext to BlockWriter; + spiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, rowContext) ? 1 : 0); + } } + reader.close(); } catch (Exception e) { throw new RuntimeException("Error in connecting to S3 and selecting the object content for object : " + s3ObjectKey, e); } @@ -329,28 +333,24 @@ public HashMap getNameValue() { } } - - /* - Method to create the Parquet read request - */ - private static SelectObjectContentRequest generateBaseParquetRequest(String bucket, String key) + @VisibleForTesting + protected ArrowReader constructArrowReader(String uri) { - SelectObjectContentRequest request = new SelectObjectContentRequest(); - request.setBucketName(bucket); - request.setKey(key); - request.setExpression(VerticaRecordHandler.QUERY); - request.setExpressionType(ExpressionType.SQL); - - InputSerialization inputSerialization = new InputSerialization(); - inputSerialization.setParquet(new ParquetInput()); - inputSerialization.setCompressionType(CompressionType.NONE); - request.setInputSerialization(inputSerialization); - - OutputSerialization outputSerialization = new OutputSerialization(); - outputSerialization.setJson(new JSONOutput()); - request.setOutputSerialization(outputSerialization); + BufferAllocator allocator = new RootAllocator(); + DatasetFactory datasetFactory = new FileSystemDatasetFactory( + allocator, + NativeMemoryPool.getDefault(), + FileFormat.PARQUET, + uri); + Dataset dataset = datasetFactory.finish(); + ScanOptions options = new ScanOptions(/*batchSize*/ 32768); + Scanner scanner = dataset.newScan(options); + return scanner.scanBatches(); + } - return request; + private static String constructS3Uri(String bucket, String key) + { + return "s3://" + bucket + "/" + key; } } diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java index ae833032d7..48091b59e0 100644 --- a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java +++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaMetadataHandlerTest.java @@ -47,11 +47,6 @@ import com.amazonaws.athena.connectors.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.athena.connectors.vertica.query.QueryFactory; import com.amazonaws.athena.connectors.vertica.query.VerticaExportQueryBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.ListObjectsRequest; -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.Region; -import com.amazonaws.services.s3.model.S3ObjectSummary; import com.google.common.collect.ImmutableList; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.After; @@ -66,6 +61,10 @@ import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.ListObjectsRequest; +import software.amazon.awssdk.services.s3.model.ListObjectsResponse; +import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest; import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse; @@ -105,7 +104,7 @@ public class VerticaMetadataHandlerTest extends TestBase private Connection connection; private SecretsManagerClient secretsManager; private AthenaClient athena; - private AmazonS3 amazonS3; + private S3Client amazonS3; private FederatedIdentity federatedIdentity; private BlockAllocatorImpl allocator; private DatabaseMetaData databaseMetaData; @@ -117,11 +116,7 @@ public class VerticaMetadataHandlerTest extends TestBase private QueryStatusChecker queryStatusChecker; private VerticaMetadataHandler verticaMetadataHandlerMocked; @Mock - private AmazonS3 s3clientMock; - @Mock - private ListObjectsRequest listObjectsRequest; - @Mock - private ObjectListing objectListing; + private S3Client s3clientMock; private DatabaseConnectionConfig databaseConnectionConfig = new DatabaseConnectionConfig("testCatalog", VERTICA_NAME, "vertica://jdbc:vertica:thin:username/password@//127.0.0.1:1521/vrt"); @@ -144,11 +139,10 @@ public void setUp() throws Exception this.schemaBuilder = Mockito.mock(SchemaBuilder.class); this.blockWriter = Mockito.mock(BlockWriter.class); this.queryStatusChecker = Mockito.mock(QueryStatusChecker.class); - this.amazonS3 = Mockito.mock(AmazonS3.class); + this.amazonS3 = Mockito.mock(S3Client.class); Mockito.lenient().when(this.secretsManager.getSecretValue(Mockito.eq(GetSecretValueRequest.builder().secretId("testSecret").build()))).thenReturn(GetSecretValueResponse.builder().secretString("{\"username\": \"testUser\", \"password\": \"testPassword\"}").build()); Mockito.when(connection.getMetaData()).thenReturn(databaseMetaData); - Mockito.when(amazonS3.getRegion()).thenReturn(Region.US_West_2); this.jdbcConnectionFactory = Mockito.mock(JdbcConnectionFactory.class, Mockito.RETURNS_DEEP_STUBS); this.connection = Mockito.mock(Connection.class, Mockito.RETURNS_DEEP_STUBS); @@ -344,21 +338,13 @@ public void doGetSplits() throws Exception BlockUtils.setValue(partitions.getFieldVector("awsRegionSql"), i, "us-west-2"); } - List s3ObjectSummariesList = new ArrayList<>(); - S3ObjectSummary s3ObjectSummary = new S3ObjectSummary(); - s3ObjectSummary.setBucketName("s3ExportBucket"); - s3ObjectSummary.setKey("testKey"); - s3ObjectSummariesList.add(s3ObjectSummary); - ListObjectsRequest listObjectsRequestObj = new ListObjectsRequest(); - listObjectsRequestObj.setBucketName("s3ExportBucket"); - listObjectsRequestObj.setPrefix("queryId"); - + List objectList = new ArrayList<>(); + S3Object obj = S3Object.builder().key("testKey").build(); + objectList.add(obj); + ListObjectsResponse listObjectsResponse = ListObjectsResponse.builder().contents(objectList).build(); Mockito.when(verticaMetadataHandlerMocked.getS3ExportBucket()).thenReturn("testS3Bucket"); - Mockito.lenient().when(listObjectsRequest.withBucketName(nullable(String.class))).thenReturn(listObjectsRequestObj); - Mockito.lenient().when(listObjectsRequest.withPrefix(nullable(String.class))).thenReturn(listObjectsRequestObj); - Mockito.when(amazonS3.listObjects(nullable(ListObjectsRequest.class))).thenReturn(objectListing); - Mockito.when(objectListing.getObjectSummaries()).thenReturn(s3ObjectSummariesList); + Mockito.when(amazonS3.listObjects(nullable(ListObjectsRequest.class))).thenReturn(listObjectsResponse); GetSplitsRequest originalReq = new GetSplitsRequest(this.federatedIdentity, "queryId", "catalog_name", new TableName("schema", "table_name"), diff --git a/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java new file mode 100644 index 0000000000..b6ec304ad3 --- /dev/null +++ b/athena-vertica/src/test/java/com/amazonaws/athena/connectors/vertica/VerticaRecordHandlerTest.java @@ -0,0 +1,349 @@ +/*- + * #%L + * athena-gcs + * %% + * Copyright (C) 2019 - 2022 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.vertica; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnitRunner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.amazonaws.athena.connector.lambda.data.Block; +import com.amazonaws.athena.connector.lambda.data.BlockAllocator; +import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl; +import com.amazonaws.athena.connector.lambda.data.BlockUtils; +import com.amazonaws.athena.connector.lambda.data.S3BlockSpillReader; +import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; +import com.amazonaws.athena.connector.lambda.domain.Split; +import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; +import com.amazonaws.athena.connector.lambda.domain.predicate.Range; +import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; +import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; +import com.amazonaws.athena.connector.lambda.domain.spill.S3SpillLocation; +import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation; +import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; +import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse; +import com.amazonaws.athena.connector.lambda.records.RecordResponse; +import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse; +import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory; +import com.amazonaws.athena.connector.lambda.security.FederatedIdentity; +import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.io.ByteStreams; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.services.athena.AthenaClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient; + +import static com.amazonaws.athena.connector.lambda.domain.predicate.Constraints.DEFAULT_NO_LIMIT; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_EXPORT_BUCKET; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_OBJECT_KEY; +import static com.amazonaws.athena.connectors.vertica.VerticaConstants.VERTICA_SPLIT_QUERY_ID; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) + +public class VerticaRecordHandlerTest + extends TestBase +{ + private static final Logger logger = LoggerFactory.getLogger(VerticaRecordHandlerTest.class); + + private VerticaRecordHandler handler; + private BlockAllocator allocator; + private List mockS3Storage = new ArrayList<>(); + private S3BlockSpillReader spillReader; + private FederatedIdentity identity = new FederatedIdentity("arn", "account", Collections.emptyMap(), Collections.emptyList()); + private EncryptionKeyFactory keyFactory = new LocalKeyFactory(); + + private static final BufferAllocator bufferAllocator = new RootAllocator(); + + @Rule + public TestName testName = new TestName(); + + @Mock + private S3Client mockS3; + + @Mock + private SecretsManagerClient mockSecretsManager; + + @Mock + private AthenaClient mockAthena; + + @Before + public void setup() + { + logger.info("{}: enter", testName.getMethodName()); + + allocator = new BlockAllocatorImpl(); + handler = new VerticaRecordHandler(mockS3, mockSecretsManager, mockAthena, com.google.common.collect.ImmutableMap.of()); + spillReader = new S3BlockSpillReader(mockS3, allocator); + + Mockito.lenient().when(mockS3.putObject(any(PutObjectRequest.class), any(RequestBody.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + InputStream inputStream = ((RequestBody) invocationOnMock.getArguments()[1]).contentStreamProvider().newStream(); + ByteHolder byteHolder = new ByteHolder(); + byteHolder.setBytes(ByteStreams.toByteArray(inputStream)); + synchronized (mockS3Storage) { + mockS3Storage.add(byteHolder); + logger.info("puObject: total size " + mockS3Storage.size()); + } + return PutObjectResponse.builder().build(); + }); + + Mockito.lenient().when(mockS3.getObject(any(GetObjectRequest.class))) + .thenAnswer((InvocationOnMock invocationOnMock) -> { + ByteHolder byteHolder; + synchronized (mockS3Storage) { + byteHolder = mockS3Storage.get(0); + mockS3Storage.remove(0); + logger.info("getObject: total size " + mockS3Storage.size()); + } + return new ResponseInputStream<>(GetObjectResponse.builder().build(), new ByteArrayInputStream(byteHolder.getBytes())); + }); + } + + @After + public void after() + { + allocator.close(); + logger.info("{}: exit ", testName.getMethodName()); + } + + @Test + public void doReadRecordsNoSpill() + throws Exception + { + logger.info("doReadRecordsNoSpill: enter"); + + VectorSchemaRoot schemaRoot = createRoot(); + ArrowReader mockReader = mock(ArrowReader.class); + when(mockReader.loadNextBatch()).thenReturn(true, false); + when(mockReader.getVectorSchemaRoot()).thenReturn(schemaRoot); + VerticaRecordHandler handlerSpy = spy(handler); + doReturn(mockReader).when(handlerSpy).constructArrowReader(any()); + + Map constraintsMap = new HashMap<>(); + constraintsMap.put("time", SortedRangeSet.copyOf(Types.MinorType.BIGINT.getType(), + ImmutableList.of(Range.equal(allocator, Types.MinorType.BIGINT.getType(), 100L)), false)); + + S3SpillLocation splitLoc = S3SpillLocation.newBuilder() + .withBucket(UUID.randomUUID().toString()) + .withSplitId(UUID.randomUUID().toString()) + .withQueryId(UUID.randomUUID().toString()) + .withIsDirectory(true) + .build(); + + Split.Builder splitBuilder = Split.newBuilder(splitLoc, keyFactory.create()) + .add(VERTICA_SPLIT_QUERY_ID, "query_id") + .add(VERTICA_SPLIT_EXPORT_BUCKET, "export_bucket") + .add(VERTICA_SPLIT_OBJECT_KEY, "s3_object_key"); + + ReadRecordsRequest request = new ReadRecordsRequest(identity, + DEFAULT_CATALOG, + QUERY_ID, + TABLE_NAME, + schemaRoot.getSchema(), + splitBuilder.build(), + new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), + 100_000_000_000L, + 100_000_000_000L//100GB don't expect this to spill + ); + RecordResponse rawResponse = handlerSpy.doReadRecords(allocator, request); + + assertTrue(rawResponse instanceof ReadRecordsResponse); + + ReadRecordsResponse response = (ReadRecordsResponse) rawResponse; + logger.info("doReadRecordsNoSpill: rows[{}]", response.getRecordCount()); + + assertTrue(response.getRecords().getRowCount() == 2); + logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 0)); + logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 1)); + + for (Field field : schemaRoot.getSchema().getFields()) { + assertTrue(response.getRecords().getFieldVector(field.getName()).getObject(0).equals(schemaRoot.getVector(field).getObject(0))); + assertTrue(response.getRecords().getFieldVector(field.getName()).getObject(1).equals(schemaRoot.getVector(field).getObject(1))); + } + + logger.info("doReadRecordsNoSpill: exit"); + } + + @Test + public void doReadRecordsSpill() + throws Exception + { + logger.info("doReadRecordsSpill: enter"); + + VectorSchemaRoot schemaRoot = createRoot(); + ArrowReader mockReader = mock(ArrowReader.class); + when(mockReader.loadNextBatch()).thenReturn(true, false); + when(mockReader.getVectorSchemaRoot()).thenReturn(schemaRoot); + VerticaRecordHandler handlerSpy = spy(handler); + doReturn(mockReader).when(handlerSpy).constructArrowReader(any()); + + Map constraintsMap = new HashMap<>(); + constraintsMap.put("time", SortedRangeSet.copyOf(Types.MinorType.BIGINT.getType(), + ImmutableList.of(Range.equal(allocator, Types.MinorType.BIGINT.getType(), 100L)), false)); + + S3SpillLocation splitLoc = S3SpillLocation.newBuilder() + .withBucket(UUID.randomUUID().toString()) + .withSplitId(UUID.randomUUID().toString()) + .withQueryId(UUID.randomUUID().toString()) + .withIsDirectory(true) + .build(); + + Split.Builder splitBuilder = Split.newBuilder(splitLoc, keyFactory.create()) + .add(VERTICA_SPLIT_QUERY_ID, "query_id") + .add(VERTICA_SPLIT_EXPORT_BUCKET, "export_bucket") + .add(VERTICA_SPLIT_OBJECT_KEY, "s3_object_key"); + + ReadRecordsRequest request = new ReadRecordsRequest(identity, + DEFAULT_CATALOG, + QUERY_ID, + TABLE_NAME, + schemaRoot.getSchema(), + splitBuilder.build(), + new Constraints(constraintsMap, Collections.emptyList(), Collections.emptyList(), DEFAULT_NO_LIMIT), + 1_500_000L, //~1.5MB so we should see some spill + 0L + ); + RecordResponse rawResponse = handlerSpy.doReadRecords(allocator, request); + + assertTrue(rawResponse instanceof RemoteReadRecordsResponse); + + try (RemoteReadRecordsResponse response = (RemoteReadRecordsResponse) rawResponse) { + logger.info("doReadRecordsSpill: remoteBlocks[{}]", response.getRemoteBlocks().size()); + + //assertTrue(response.getNumberBlocks() > 1); + + int blockNum = 0; + for (SpillLocation next : response.getRemoteBlocks()) { + S3SpillLocation spillLocation = (S3SpillLocation) next; + try (Block block = spillReader.read(spillLocation, response.getEncryptionKey(), response.getSchema())) { + + logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]", blockNum++, block.getRowCount()); + // assertTrue(++blockNum < response.getRemoteBlocks().size() && block.getRowCount() > 10_000); + + logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 0)); + assertNotNull(BlockUtils.rowToString(block, 0)); + } + } + } + + logger.info("doReadRecordsSpill: exit"); + } + + private class ByteHolder + { + private byte[] bytes; + + public void setBytes(byte[] bytes) + { + this.bytes = bytes; + } + + public byte[] getBytes() + { + return bytes; + } + } + + private VectorSchemaRoot createRoot() + { + Schema schema = SchemaBuilder.newBuilder() + .addBigIntField("day") + .addBigIntField("month") + .addBigIntField("year") + .addStringField("preparedStmt") + .addStringField("queryId") + .addStringField("awsRegionSql") + .build(); + VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, bufferAllocator); + BigIntVector dayVector = (BigIntVector) schemaRoot.getVector("day"); + dayVector.allocateNew(2); + dayVector.set(0, 0); + dayVector.set(1, 1); + dayVector.setValueCount(2); + BigIntVector monthVector = (BigIntVector) schemaRoot.getVector("month"); + monthVector.allocateNew(2); + monthVector.set(0, 0); + monthVector.set(1, 1); + monthVector.setValueCount(2); + BigIntVector yearVector = (BigIntVector) schemaRoot.getVector("year"); + yearVector.allocateNew(2); + yearVector.set(0, 2000); + yearVector.set(1, 2001); + yearVector.setValueCount(2); + VarCharVector stmtVector = (VarCharVector) schemaRoot.getVector("preparedStmt"); + stmtVector.allocateNew(2); + stmtVector.set(0, new Text("test1")); + stmtVector.set(1, new Text("test2")); + stmtVector.setValueCount(2); + VarCharVector idVector = (VarCharVector) schemaRoot.getVector("queryId"); + idVector.allocateNew(2); + idVector.set(0, new Text("queryID1")); + idVector.set(1, new Text("queryID2")); + idVector.setValueCount(2); + VarCharVector regionVector = (VarCharVector) schemaRoot.getVector("awsRegionSql"); + regionVector.allocateNew(2); + regionVector.set(0, new Text("region1")); + regionVector.set(1, new Text("region2")); + regionVector.setValueCount(2); + schemaRoot.setRowCount(2); + return schemaRoot; + } +}