Skip to content
This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Commit

Permalink
Allows giving providerclassname, used to create awsCredentialsProvide…
Browse files Browse the repository at this point in the history
…r object
  • Loading branch information
fogrid committed Jun 15, 2020
1 parent ba7e8c8 commit a390bca
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
package com.audienceproject.spark.dynamodb.connector

import com.amazonaws.auth.profile.ProfileCredentialsProvider
import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.auth.{AWSCredentialsProvider, AWSStaticCredentialsProvider, BasicSessionCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.services.dynamodbv2.document.{DynamoDB, ItemCollection, ScanOutcome}
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDB, AmazonDynamoDBAsync, AmazonDynamoDBAsyncClientBuilder, AmazonDynamoDBClientBuilder}
Expand All @@ -33,14 +33,16 @@ private[dynamodb] trait DynamoConnector {

@transient private lazy val properties = sys.props

def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None): DynamoDB = {
val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn)
def getDynamoDB(region: Option[String] = None, roleArn: Option[String] = None, providerClassName: Option[String] = None): DynamoDB = {
val client: AmazonDynamoDB = getDynamoDBClient(region, roleArn, providerClassName)
new DynamoDB(client)
}

private def getDynamoDBClient(region: Option[String] = None, roleArn: Option[String] = None): AmazonDynamoDB = {
private def getDynamoDBClient(region: Option[String] = None,
roleArn: Option[String] = None,
providerClassName: Option[String]): AmazonDynamoDB = {
val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
val credentials = getCredentials(chosenRegion, roleArn)
val credentials = getCredentials(chosenRegion, roleArn, providerClassName)

properties.get("aws.dynamodb.endpoint").map(endpoint => {
AmazonDynamoDBClientBuilder.standard()
Expand All @@ -55,9 +57,11 @@ private[dynamodb] trait DynamoConnector {
)
}

def getDynamoDBAsyncClient(region: Option[String] = None, roleArn: Option[String] = None): AmazonDynamoDBAsync = {
def getDynamoDBAsyncClient(region: Option[String] = None,
roleArn: Option[String] = None,
providerClassName: Option[String] = None): AmazonDynamoDBAsync = {
val chosenRegion = region.getOrElse(properties.getOrElse("aws.dynamodb.region", "us-east-1"))
val credentials = getCredentials(chosenRegion, roleArn)
val credentials = getCredentials(chosenRegion, roleArn, providerClassName)

properties.get("aws.dynamodb.endpoint").map(endpoint => {
AmazonDynamoDBAsyncClientBuilder.standard()
Expand All @@ -73,10 +77,15 @@ private[dynamodb] trait DynamoConnector {
}

/**
* Get credentials from a passed in arn or from profile or return the default credential provider
**/
private def getCredentials(chosenRegion: String, roleArn: Option[String]) = {
roleArn.map(arn => {
* Get credentials from an instantiated object of the class name given
* or a passed in arn
* or from profile
* or return the default credential provider
**/
private def getCredentials(chosenRegion: String, roleArn: Option[String], providerClassName: Option[String]) = {
providerClassName.map(providerClass => {
Class.forName(providerClass).newInstance.asInstanceOf[AWSCredentialsProvider]
}).orElse(roleArn.map(arn => {
val stsClient = properties.get("aws.sts.endpoint").map(endpoint => {
AWSSecurityTokenServiceClientBuilder
.standard()
Expand All @@ -103,7 +112,7 @@ private[dynamodb] trait DynamoConnector {
stsCredentials.getSessionToken
)
new AWSStaticCredentialsProvider(assumeCreds)
}).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_)))
})).orElse(properties.get("aws.profile").map(new ProfileCredentialsProvider(_)))
.getOrElse(new DefaultAWSCredentialsProviderChain)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
private val filterPushdown = parameters.getOrElse("filterpushdown", "true").toBoolean
private val region = parameters.get("region")
private val roleArn = parameters.get("rolearn")
private val providerClassName = parameters.get("providerclassname")

override val filterPushdownEnabled: Boolean = filterPushdown

override val (keySchema, readLimit, writeLimit, itemLimit, totalSegments) = {
val table = getDynamoDB(region, roleArn).getTable(tableName)
val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
val desc = table.describe()

// Key schema.
Expand Down Expand Up @@ -106,7 +107,7 @@ private[dynamodb] class TableConnector(tableName: String, parallelism: Int, para
scanSpec.withExpressionSpec(xspec.buildForScan())
}

getDynamoDB(region, roleArn).getTable(tableName).scan(scanSpec)
getDynamoDB(region, roleArn, providerClassName).getTable(tableName).scan(scanSpec)
}

override def putItems(columnSchema: ColumnSchema, items: Seq[InternalRow])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
private val filterPushdown = parameters.getOrElse("filterPushdown", "true").toBoolean
private val region = parameters.get("region")
private val roleArn = parameters.get("roleArn")
private val providerClassName = parameters.get("providerclassname")

override val filterPushdownEnabled: Boolean = filterPushdown

override val (keySchema, readLimit, itemLimit, totalSegments) = {
val table = getDynamoDB(region, roleArn).getTable(tableName)
val table = getDynamoDB(region, roleArn, providerClassName).getTable(tableName)
val indexDesc = table.describe().getGlobalSecondaryIndexes.asScala.find(_.getIndexName == indexName).get

// Key schema.
Expand Down Expand Up @@ -96,7 +97,7 @@ private[dynamodb] class TableIndexConnector(tableName: String, indexName: String
scanSpec.withExpressionSpec(xspec.buildForScan())
}

getDynamoDB(region, roleArn).getTable(tableName).getIndex(indexName).scan(scanSpec)
getDynamoDB(region, roleArn, providerClassName).getTable(tableName).getIndex(indexName).scan(scanSpec)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ class DynamoWriterFactory(connector: TableConnector,

private val region = parameters.get("region")
private val roleArn = parameters.get("rolearn")
private val providerClassName = parameters.get("providerclassname")

override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = {
val columnSchema = new ColumnSchema(connector.keySchema, schema)
val client = connector.getDynamoDB(region, roleArn)
val client = connector.getDynamoDB(region, roleArn, providerClassName)
if (update) {
assert(!delete, "Please provide exactly one of 'update' or 'delete' options.")
new DynamoUpdateWriter(columnSchema, connector, client)
Expand Down

0 comments on commit a390bca

Please sign in to comment.