Cosmos spark config infra (#17702)
* cosmos spark config infra
moderakh authored Nov 20, 2020
1 parent 744aa1c commit bb658bf
Showing 6 changed files with 169 additions and 15 deletions.
@@ -4,10 +4,10 @@ package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

class CosmosBatchWriter extends BatchWrite with CosmosLoggingTrait {
class CosmosBatchWriter(userConfig: Map[String, String]) extends BatchWrite with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory()
override def createBatchWriterFactory(physicalWriteInfo: PhysicalWriteInfo): DataWriterFactory = new CosmosDataWriteFactory(userConfig)

override def commit(writerCommitMessages: Array[WriterCommitMessage]): Unit = {
// TODO
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.spark

import java.net.URL
import java.util.Locale


// each config category will be a case class:
// TODO moderakh more configs
//case class ClientConfig()
//case class CosmosBatchWriteConfig()

case class CosmosAccountConfig(endpoint: String, key: String)

object CosmosAccountConfig {
val CosmosAccountEndpointUri = CosmosConfigEntry[String](key = "spark.cosmos.accountEndpoint",
mandatory = true,
parseFromStringFunction = accountEndpointUri => {
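    // constructing a java.net.URL here only validates the endpoint format; the original string is returned unchanged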
new URL(accountEndpointUri)
accountEndpointUri
},
helpMessage = "Cosmos DB Account Endpoint Uri")

val CosmosKey = CosmosConfigEntry[String](key = "spark.cosmos.accountKey",
mandatory = true,
parseFromStringFunction = accountKey => accountKey,
helpMessage = "Cosmos DB Account Key")

def parseCosmosAccountConfig(cfg: Map[String, String]): CosmosAccountConfig = {
val endpointOpt = CosmosConfigEntry.parse(cfg, CosmosAccountEndpointUri)
val key = CosmosConfigEntry.parse(cfg, CosmosKey)

// parsing above already validated these assertions
assert(endpointOpt.isDefined)
assert(key.isDefined)

CosmosAccountConfig(endpointOpt.get, key.get)
}
}

case class CosmosContainerConfig(database: String, container: String)

object CosmosContainerConfig {
val databaseName = CosmosConfigEntry[String](key = "spark.cosmos.database",
mandatory = true,
parseFromStringFunction = database => database,
helpMessage = "Cosmos DB database name")

val containerName = CosmosConfigEntry[String](key = "spark.cosmos.container",
mandatory = true,
parseFromStringFunction = container => container,
helpMessage = "Cosmos DB container name")

def parseCosmosContainerConfig(cfg: Map[String, String]): CosmosContainerConfig = {
val databaseOpt = CosmosConfigEntry.parse(cfg, databaseName)
val containerOpt = CosmosConfigEntry.parse(cfg, containerName)

// parsing above already validated this
assert(databaseOpt.isDefined)
assert(containerOpt.isDefined)

CosmosContainerConfig(databaseOpt.get, containerOpt.get)
}
}

case class CosmosConfigEntry[T](key: String,
mandatory: Boolean,
defaultValue: Option[String] = Option.empty,
parseFromStringFunction: String => T,
helpMessage: String) {

def parse(paramAsString: String) : T = {
try {
parseFromStringFunction(paramAsString)
} catch {
case e: Exception => throw new RuntimeException(s"invalid configuration for ${key}:${paramAsString}. Config description: ${helpMessage}", e)
}
}
}

// TODO: moderakh how to merge user config with SparkConf application config?
object CosmosConfigEntry {
def parse[T](configuration: Map[String, String], configEntry: CosmosConfigEntry[T]): Option[T] = {
// TODO moderakh: where should we handle case sensitivity?
// we are doing this here per config parsing for now
val opt = configuration.map { case (key, value) => (key.toLowerCase(Locale.ROOT), value) }.get(configEntry.key.toLowerCase(Locale.ROOT))
if (opt.isDefined) {
Option.apply(configEntry.parse(opt.get))
}
else {
if (configEntry.mandatory) {
throw new RuntimeException(s"mandatory option ${configEntry.key} is missing. Config description: ${configEntry.helpMessage}")
} else {
Option.empty
}
}
}
}
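
For context, a minimal usage sketch of how a Spark job would supply the four keys defined above. The data source short name ("cosmos.items"), the account endpoint, and the key lookup are assumptions for illustration, not part of this commit:

import org.apache.spark.sql.SparkSession

// Hypothetical usage sketch: assumes the connector is registered under the
// short name "cosmos.items"; endpoint and key values below are placeholders.
val spark = SparkSession.builder().appName("cosmos-config-demo").getOrCreate()
val df = spark.createDataFrame(Seq(("id-1", "Alice"))).toDF("id", "name")
df.write
  .format("cosmos.items")
  .option("spark.cosmos.accountEndpoint", "https://myaccount.documents.azure.com:443/")
  .option("spark.cosmos.accountKey", sys.env("COSMOS_KEY"))
  .option("spark.cosmos.database", "testDB")
  .option("spark.cosmos.container", "testContainer")
  .mode("append")
  .save()

The options map passed here is the userConfig that flows down to CosmosTable, CosmosWriterBuilder, and CosmosDataWriteFactory in the hunks below.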
@@ -10,22 +10,23 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class CosmosDataWriteFactory extends DataWriterFactory with CosmosLoggingTrait {
class CosmosDataWriteFactory(userConfig: Map[String, String]) extends DataWriterFactory with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

override def createWriter(i: Int, l: Long): DataWriter[InternalRow] = new CosmosWriter()

class CosmosWriter() extends DataWriter[InternalRow] {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

// TODO moderakh account config and databaseName, containerName need to passed down from the user
val cosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
val cosmosTargetContainerConfig = CosmosContainerConfig.parseCosmosContainerConfig(userConfig)

// TODO moderakh: this needs to be shared to avoid creating multiple clients
val client = new CosmosClientBuilder()
.key(TestConfigurations.MASTER_KEY)
.endpoint(TestConfigurations.HOST)
.key(cosmosAccountConfig.key)
.endpoint(cosmosAccountConfig.endpoint)
.consistencyLevel(ConsistencyLevel.EVENTUAL)
.buildAsyncClient();
val databaseName = "testDB"
val containerName = "testContainer"

override def write(internalRow: InternalRow): Unit = {
// TODO moderakh: schema is hard-coded for now to make the end-to-end TestE2EMain work; implement schema inference code
@@ -36,8 +37,8 @@ class CosmosDataWriteFactory extends DataWriterFactory with CosmosLoggingTrait {
if (!objectNode.has("id")) {
objectNode.put("id", UUID.randomUUID().toString)
}
client.getDatabase(databaseName)
.getContainer(containerName)
client.getDatabase(cosmosTargetContainerConfig.database)
.getContainer(cosmosTargetContainerConfig.container)
.createItem(objectNode)
.block()
}
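
One possible way to address the "needs to be shared" TODO above, sketched here as an assumption and not part of this commit, is to cache one CosmosAsyncClient per account endpoint so that all writers in the same JVM reuse it:

import scala.collection.concurrent.TrieMap
import com.azure.cosmos.{ConsistencyLevel, CosmosAsyncClient, CosmosClientBuilder}

// Hypothetical sketch: one cached client per account endpoint per JVM.
object CosmosClientCache {
  private val clients = new TrieMap[String, CosmosAsyncClient]

  def getOrCreate(accountConfig: CosmosAccountConfig): CosmosAsyncClient =
    clients.getOrElseUpdate(accountConfig.endpoint,
      new CosmosClientBuilder()
        .endpoint(accountConfig.endpoint)
        .key(accountConfig.key)
        .consistencyLevel(ConsistencyLevel.EVENTUAL)
        .buildAsyncClient())
}

The writer could then call CosmosClientCache.getOrCreate(cosmosAccountConfig) instead of building a new client per task.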
@@ -17,11 +17,11 @@ import scala.collection.JavaConverters._
* CosmosTable is the entry point; it is registered with Spark.
* @param userProvidedSchema
* @param transforms
* @param map
* @param userConfig
*/
class CosmosTable(val userProvidedSchema: StructType,
val transforms: Array[Transform],
val map: util.Map[String, String])
val userConfig: util.Map[String, String])
extends Table with SupportsWrite with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

@@ -35,6 +35,6 @@ class CosmosTable(val userProvidedSchema: StructType,

override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava

override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = new CosmosWriterBuilder
override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = new CosmosWriterBuilder(userConfig.asScala.toMap)

}
@@ -4,8 +4,8 @@ package com.azure.cosmos.spark

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}

class CosmosWriterBuilder extends WriteBuilder with CosmosLoggingTrait {
class CosmosWriterBuilder(userConfig: Map[String, String]) extends WriteBuilder with CosmosLoggingTrait {
logInfo(s"Instantiated ${this.getClass.getSimpleName}")

override def buildForBatch(): BatchWrite = new CosmosBatchWriter()
override def buildForBatch(): BatchWrite = new CosmosBatchWriter(userConfig)
}
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import org.assertj.core.api.Assertions.assertThat

class CosmosConfigSpec extends UnitSpec {
//scalastyle:off multiple.string.literals

"account endpoint" should "be parsed" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "https://localhsot:8081",
"spark.cosmos.accountKey" -> "xyz"
)

val endpointConfig = CosmosAccountConfig.parseCosmosAccountConfig(userConfig)

assertThat(endpointConfig.endpoint).isEqualTo("https://localhost:8081")
assertThat(endpointConfig.key).isEqualTo("xyz")
}

"account endpoint" should "be validated" in {
val userConfig = Map(
"spark.cosmos.accountEndpoint" -> "invalidUrl",
"spark.cosmos.accountKey" -> "xyz"
)

try {
CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
fail("invalid URL")
} catch {
case e: Exception => assertThat(e.getMessage).isEqualTo(
"invalid configuration for spark.cosmos.accountEndpoint:invalidUrl." +
" Config description: Cosmos DB Account Endpoint Uri")
}
}

"account endpoint" should "mandatory config" in {
val userConfig = Map(
"spark.cosmos.accountKey" -> "xyz"
)

try {
CosmosAccountConfig.parseCosmosAccountConfig(userConfig)
fail("missing URL")
} catch {
case e: Exception => assertThat(e.getMessage).isEqualTo(
"mandatory option spark.cosmos.accountEndpoint is missing." +
" Config description: Cosmos DB Account Endpoint Uri")
}
}
//scalastyle:on multiple.string.literals
}
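
A companion test for the container config, following the same pattern (a sketch, not part of this commit):

  "database and container" should "be parsed" in {
    val userConfig = Map(
      "spark.cosmos.database" -> "testDB",
      "spark.cosmos.container" -> "testContainer"
    )

    val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(userConfig)

    assertThat(containerConfig.database).isEqualTo("testDB")
    assertThat(containerConfig.container).isEqualTo("testContainer")
  }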
