
[SPARK-12010][SQL] Add columnMapping support #10312

Closed · wants to merge 3 commits
11 changes: 9 additions & 2 deletions python/pyspark/sql/readwriter.py
@@ -494,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None):
self._jwrite.orc(path)

@since(1.4)
def jdbc(self, url, table, mode=None, properties=None):
def jdbc(self, url, table, mode=None, properties=None, columnMapping=None):
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.

.. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -511,13 +511,20 @@ def jdbc(self, url, table, mode=None, properties=None):
:param properties: JDBC database connection arguments, a list of
arbitrary string tag/value. Normally at least a
"user" and "password" property should be included.
:param columnMapping: optional mapping from DataFrame field names to
                      JDBC table column names.
"""
if properties is None:
properties = dict()
jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
self._jwrite.mode(mode).jdbc(url, table, jprop)
if columnMapping is None:
columnMapping = dict()
jcolumnMapping = JavaClass("java.util.HashMap", self._sqlContext._sc._gateway._gateway_client)()
for k in columnMapping:
jcolumnMapping.put(k, columnMapping[k])
self._jwrite.mode(mode).jdbc(url, table, jprop, jcolumnMapping)


def _test():
42 changes: 40 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -265,10 +265,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included.
* @param columnMapping Maps DataFrame column names to target table column names.
*                      This parameter can be omitted if the target table is (or will be)
*                      created by this method, since the target table structure then
*                      matches the DataFrame structure.
*                      It is strongly recommended if the target table already exists
*                      and was created outside of this method.
*                      If omitted, the SQL INSERT statement will not include column names,
*                      which means that the field ordering of the DataFrame must match
*                      the column ordering of the target table.
*
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
def jdbc(url: String,
table: String,
connectionProperties: Properties,
columnMapping: scala.collection.Map[String, String]): Unit = {
val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
@@ -303,7 +315,33 @@ final class DataFrameWriter private[sql](df: DataFrame) {
conn.close()
}

JdbcUtils.saveTable(df, url, table, props)
JdbcUtils.saveTable(df, url, table, props, columnMapping)
}

/**
* (Java-friendly) version of
* [[DataFrameWriter.jdbc(String,String,Properties,scala.collection.Map[String,String]):]]
*/
def jdbc(url: String,
table: String,
connectionProperties: Properties,
Review comment (Member): Likewise these are indented 1 space too little

columnMapping: java.util.Map[String, String]): Unit = {
// Convert the Java Map into a Scala Map (null stays null).
val sColumnMapping: scala.collection.Map[String, String] =
  if (columnMapping != null) columnMapping.asScala else null
jdbc(url, table, connectionProperties, sColumnMapping)
}

/**
* Three-parameter version of
* [[DataFrameWriter.jdbc(String,String,Properties,scala.collection.Map[String,String]):]]
*/
def jdbc(url: String,
table: String,
connectionProperties: Properties): Unit = {
jdbc(url, table, connectionProperties, null.asInstanceOf[Map[String, String]])
}

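For context, a minimal usage sketch of the new four-argument overload (not part of the diff). The JDBC URL, table name, credentials, and column names below are hypothetical; the target table is assumed to already exist with column names that differ from the DataFrame's field names, which is the case the columnMapping parameter is meant for.

```scala
import java.util.Properties

// Hypothetical DataFrame with fields "id" and "name".
val df = sqlContext.createDataFrame(Seq((1, "fred"), (2, "mary"))).toDF("id", "name")

val props = new Properties()
props.setProperty("user", "sa")
props.setProperty("password", "")

// The existing table TEST.PEOPLE has columns PERSON_ID and PERSON_NAME, so the
// mapping translates DataFrame field names into the table's column names.
val columnMapping = Map("id" -> "PERSON_ID", "name" -> "PERSON_NAME")

df.write
  .mode("append")
  .jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", props, columnMapping)
```

The same mapping can be passed from Python as a plain dict through the readwriter.py change above.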
@@ -61,16 +61,16 @@ object JdbcUtils extends Logging {

/**
* Returns a PreparedStatement that inserts a row into table via conn.
* If a columnMapping is provided, it will be used to translate RDD
* column names into table column names.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
var fieldsLeft = rddSchema.fields.length
while (fieldsLeft > 0) {
sql.append("?")
if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
fieldsLeft = fieldsLeft - 1
}
conn.prepareStatement(sql.toString())
def insertStatement(conn: Connection,
dialect: JdbcDialect,
table: String,
rddSchema: StructType,
columnMapping: scala.collection.Map[String, String]): PreparedStatement = {
val sql = dialect.getInsertStatement(table, rddSchema, columnMapping)
conn.prepareStatement(sql)
}

/**
@@ -122,6 +122,7 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
columnMapping: scala.collection.Map[String, String] = null,
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
@@ -139,7 +140,7 @@ object JdbcUtils extends Logging {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
}
val stmt = insertStatement(conn, table, rddSchema)
val stmt = insertStatement(conn, dialect, table, rddSchema, columnMapping)
try {
var rowCount = 0
while (iterator.hasNext) {
@@ -234,7 +235,8 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
properties: Properties = new Properties()) {
properties: Properties = new Properties(),
columnMapping: scala.collection.Map[String, String] = null) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
@@ -245,7 +247,8 @@ object JdbcUtils extends Logging {
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
savePartition(getConnection, table, iterator, rddSchema, nullTypes,
columnMapping, batchSize, dialect)
}
}

@@ -108,6 +108,30 @@ abstract class JdbcDialect extends Serializable {
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}

/**
* Get the SQL statement that should be used to insert new records into the table.
* Dialects can override this method to return a statement that works best in a particular
* database.
* @param table The name of the table.
* @param rddSchema The schema of the DataFrame to be inserted.
* @param columnMapping An optional mapping from DataFrame field names to database column
*                      names.
* @return The SQL statement to use for inserting into the table.
*/
def getInsertStatement(table: String,
rddSchema: StructType,
columnMapping: scala.collection.Map[String, String] = null): String = {
if (columnMapping == null) {
rddSchema.fields.map(_ => "?")
.mkString(s"INSERT INTO $table VALUES (", ", ", " ) ")
} else {
rddSchema.fields.map(field => columnMapping.getOrElse(field.name, field.name))
  .mkString(s"INSERT INTO $table ( ", ", ", " ) ") +
  rddSchema.fields.map(_ => "?").mkString("VALUES ( ", ", ", " )")
}
}

}

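To make the default behaviour concrete, a small sketch of the SQL that getInsertStatement produces with and without a mapping (table and column names are made up, and the whitespace of the generated strings is approximate). Note that the "?" placeholders are always emitted in DataFrame field order; the mapping only renames the columns listed in the statement, while values are still bound positionally in field order by savePartition.

```scala
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

// Resolve a dialect the same way saveTable does.
val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")

// No mapping: a purely positional insert, so the DataFrame field order must
// match the column order of the target table.
dialect.getInsertStatement("PEOPLE", schema)
// roughly: INSERT INTO PEOPLE VALUES (?, ? )

// With a mapping: column names are listed explicitly, taken from the mapping and
// falling back to the field name for unmapped fields ("name" here).
dialect.getInsertStatement("PEOPLE", schema, Map("id" -> "PERSON_ID"))
// roughly: INSERT INTO PEOPLE ( PERSON_ID, name ) VALUES ( ?, ? )
```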
@@ -92,6 +92,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {

df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
assert(
2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).collect()(0).length)
}

test("Basic CREATE with columnMapping") {
val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)

val columnMapping = Map("name" -> "name", "id" -> "id")
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties, columnMapping)
assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count)
assert(
2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
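The new test above maps every field to a column of the same name, so it never exercises an actual rename. A possible follow-up test sketch, not part of this PR (the table name TEST.MAPPEDAPPEND and the renamed field names are hypothetical), that appends through a mapping whose keys differ from the target column names:

```scala
test("APPEND with columnMapping to differently named columns") {
  val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)

  // The first write creates the table with columns named after the DataFrame
  // fields ("name", "id").
  df.write.jdbc(url, "TEST.MAPPEDAPPEND", new Properties)

  // A second DataFrame with different field names is appended via the mapping.
  val renamed = df.toDF("person_name", "person_id")
  val columnMapping = Map("person_name" -> "name", "person_id" -> "id")
  renamed.write.mode("append").jdbc(url, "TEST.MAPPEDAPPEND", new Properties, columnMapping)

  assert(4 === sqlContext.read.jdbc(url, "TEST.MAPPEDAPPEND", new Properties).count)
}
```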