Commit 18907cb

[SPARK-22356] Support overlapped columns between data and partition schema in data source tables
1 parent 1051ebe

4 files changed (+85, -25 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala

Lines changed: 23 additions & 8 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.types.StructType
 
 /**
  * A command used to create a data source table.
@@ -85,14 +86,28 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
       }
     }
 
-    val newTable = table.copy(
-      schema = dataSource.schema,
-      partitionColumnNames = partitionColumnNames,
-      // If metastore partition management for file source tables is enabled, we start off with
-      // partition provider hive, but no partitions in the metastore. The user has to call
-      // `msck repair table` to populate the table partitions.
-      tracksPartitionsInCatalog = partitionColumnNames.nonEmpty &&
-        sessionState.conf.manageFilesourcePartitions)
+    val newTable = dataSource match {
+      // Since Spark 2.1, we store the inferred schema of data source in metastore, to avoid
+      // inferring the schema again at read path. However if the data source has overlapped columns
+      // between data and partition schema, we can't store it in metastore as it breaks the
+      // assumption of table schema. Here we fallback to the behavior of Spark prior to 2.1, store
+      // empty schema in metastore and infer it at runtime. Note that this also means the new
+      // scalable partitioning handling feature(introduced at Spark 2.1) is disabled in this case.
+      case r: HadoopFsRelation if r.overlappedPartCols.nonEmpty =>
+        table.copy(schema = new StructType(), partitionColumnNames = Nil)
+
+      case _ =>
+        table.copy(
+          schema = dataSource.schema,
+          partitionColumnNames = partitionColumnNames,
+          // If metastore partition management for file source tables is enabled, we start off with
+          // partition provider hive, but no partitions in the metastore. The user has to call
+          // `msck repair table` to populate the table partitions.
+          tracksPartitionsInCatalog = partitionColumnNames.nonEmpty &&
+            sessionState.conf.manageFilesourcePartitions)
+
+    }
+
     // We will return Nil or throw exception at the beginning if the table already exists, so when
     // we reach here, the table should not exist and we should set `ignoreIfExists` to false.
     sessionState.catalog.createTable(newTable, ignoreIfExists = false)
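
The case this change handles is a table whose data files already contain a column that also appears as a partition directory. A minimal sketch of how to reproduce it, assuming a spark-shell session (so `spark` and its implicits are available) and a hypothetical scratch path `/tmp/overlap_demo` and table name `overlap_demo`, mirroring the SQLQuerySuite test added below:

import spark.implicits._

// The data files contain a column `p`, and they are also written under a partition
// directory `p=1`, so the data schema and the partition schema overlap.
Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j")
  .write.mode("overwrite").parquet("/tmp/overlap_demo/p=1")

// With this change the table is stored with an empty schema in the metastore and its
// schema is inferred at read time: the data schema's column order (i, p, j) is kept and
// the value of `p` comes from the partition directory, so both rows read back as (1, 1, 1).
spark.sql("create table overlap_demo using parquet options (path '/tmp/overlap_demo')")
spark.table("overlap_demo").show()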

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala

Lines changed: 17 additions & 8 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.util.Locale
+
 import scala.collection.mutable
 
 import org.apache.spark.sql.{SparkSession, SQLContext}
@@ -50,15 +52,22 @@ case class HadoopFsRelation(
 
   override def sqlContext: SQLContext = sparkSession.sqlContext
 
-  val schema: StructType = {
-    val getColName: (StructField => String) =
-      if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase
-    val overlappedPartCols = mutable.Map.empty[String, StructField]
-    partitionSchema.foreach { partitionField =>
-      if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
-        overlappedPartCols += getColName(partitionField) -> partitionField
-      }
+  private def getColName(f: StructField): String = {
+    if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
+      f.name
+    } else {
+      f.name.toLowerCase(Locale.ROOT)
+    }
+  }
+
+  val overlappedPartCols = mutable.Map.empty[String, StructField]
+  partitionSchema.foreach { partitionField =>
+    if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
+      overlappedPartCols += getColName(partitionField) -> partitionField
     }
+  }
+
+  val schema: StructType = {
     StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
       partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
   }
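
The resulting table schema keeps the data schema's column order, replaces an overlapped data column with the corresponding partition field, and appends only the partition columns that do not overlap. A standalone sketch of that merge, runnable in spark-shell, with hypothetical dataSchema and partitionSchema values and the lowercased name key used when spark.sql.caseSensitive is false:

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical schemas: `p` is present in the data files and also appears as a
// partition directory column `P`.
val dataSchema = StructType(Seq(
  StructField("i", IntegerType), StructField("p", IntegerType), StructField("j", IntegerType)))
val partitionSchema = StructType(Seq(StructField("P", IntegerType)))

// Case-insensitive column-name key, as in getColName when caseSensitiveAnalysis is false.
def getColName(f: StructField): String = f.name.toLowerCase(Locale.ROOT)

val overlappedPartCols = mutable.Map.empty[String, StructField]
partitionSchema.foreach { partitionField =>
  if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
    overlappedPartCols += getColName(partitionField) -> partitionField
  }
}

// Data-schema order is preserved, the overlapped column takes the partition field's
// definition, and only non-overlapping partition columns are appended at the end.
val schema = StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
  partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))

println(schema.fieldNames.mkString(", "))  // prints: i, P, j

Because the lookup key is lowercased here, a data column `p` and a partition directory `P=1` are treated as the same column; with spark.sql.caseSensitive=true they would not overlap and `P` would simply be appended after `j`.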

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 16 additions & 0 deletions
@@ -2741,4 +2741,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert (aggregateExpressions.isDefined)
     assert (aggregateExpressions.get.size == 2)
   }
+
+  test("SPARK-22356: overlapped columns between data and partition schema in data source tables") {
+    withTempPath { path =>
+      Seq((1, 1, 1), (1, 2, 1)).toDF("i", "p", "j")
+        .write.mode("overwrite").parquet(new File(path, "p=1").getCanonicalPath)
+      withTable("t") {
+        sql(s"create table t using parquet options(path='${path.getCanonicalPath}')")
+        // We should respect the column order in data schema.
+        assert(spark.table("t").columns === Array("i", "p", "j"))
+        checkAnswer(spark.table("t"), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+        // The DESC TABLE should report same schema as table scan.
+        assert(sql("desc t").select("col_name")
+          .as[String].collect().mkString(",").contains("i,p,j"))
+      }
+    }
+  }
 }

sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala

Lines changed: 29 additions & 9 deletions
@@ -40,7 +40,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
   private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data")
   // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to
   // avoid downloading Spark of different versions in each run.
-  private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark")
+  private val sparkTestingDir = new File("/tmp/test-spark")
   private val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
 
   override def afterAll(): Unit = {
@@ -77,35 +77,38 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
     super.beforeAll()
 
     val tempPyFile = File.createTempFile("test", ".py")
+    // scalastyle:off line.size.limit
     Files.write(tempPyFile.toPath,
       s"""
         |from pyspark.sql import SparkSession
+        |import os
         |
         |spark = SparkSession.builder.enableHiveSupport().getOrCreate()
        |version_index = spark.conf.get("spark.sql.test.version.index", None)
         |
         |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index))
         |
-        |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\
-        |          " using parquet as select 1 i")
+        |spark.sql("create table hive_compatible_data_source_tbl_{} using parquet as select 1 i".format(version_index))
         |
         |json_file = "${genDataDir("json_")}" + str(version_index)
         |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file)
-        |spark.sql("create table external_data_source_tbl_" + version_index + \\
-        |          "(i int) using json options (path '{}')".format(json_file))
+        |spark.sql("create table external_data_source_tbl_{}(i int) using json options (path '{}')".format(version_index, json_file))
         |
         |parquet_file = "${genDataDir("parquet_")}" + str(version_index)
         |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file)
-        |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\
-        |          "(i int) using parquet options (path '{}')".format(parquet_file))
+        |spark.sql("create table hive_compatible_external_data_source_tbl_{}(i int) using parquet options (path '{}')".format(version_index, parquet_file))
         |
         |json_file2 = "${genDataDir("json2_")}" + str(version_index)
         |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2)
-        |spark.sql("create table external_table_without_schema_" + version_index + \\
-        |          " using json options (path '{}')".format(json_file2))
+        |spark.sql("create table external_table_without_schema_{} using json options (path '{}')".format(version_index, json_file2))
+        |
+        |parquet_file2 = "${genDataDir("parquet2_")}" + str(version_index)
+        |spark.range(1, 3).selectExpr("1 as i", "cast(id as int) as p", "1 as j").write.parquet(os.path.join(parquet_file2, "p=1"))
+        |spark.sql("create table tbl_with_col_overlap_{} using parquet options(path '{}')".format(version_index, parquet_file2))
         |
         |spark.sql("create view v_{} as select 1 i".format(version_index))
       """.stripMargin.getBytes("utf8"))
+    // scalastyle:on line.size.limit
 
     PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) =>
       val sparkHome = new File(sparkTestingDir, s"spark-$version")
@@ -153,6 +156,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils {
       .enableHiveSupport()
       .getOrCreate()
     spark = session
+    import session.implicits._
 
     testingVersions.indices.foreach { index =>
       Seq(
@@ -194,6 +198,22 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils {
 
       // test permanent view
       checkAnswer(sql(s"select i from v_$index"), Row(1))
+
+      // SPARK-22356: overlapped columns between data and partition schema in data source tables
+      val tbl_with_col_overlap = s"tbl_with_col_overlap_$index"
+      // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0.
+      if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") {
+        spark.sql("msck repair table " + tbl_with_col_overlap)
+        assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p"))
+        checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+        assert(sql("desc " + tbl_with_col_overlap).select("col_name")
+          .as[String].collect().mkString(",").contains("i,j,p"))
+      } else {
+        assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j"))
+        checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil)
+        assert(sql("desc " + tbl_with_col_overlap).select("col_name")
+          .as[String].collect().mkString(",").contains("i,p,j"))
+      }
     }
   }
 }
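
Per the branches above, overlapped-column tables written by Spark 2.1.x or 2.2.0 need an explicit `msck repair table` before the first read and report the partition column last (i, j, p); tables written by other versions infer the schema at read time and keep the data-schema order (i, p, j). For illustration only, a hypothetical helper (not part of the patch) restating the version check, assuming version strings in the same form as testingVersions:

// Hypothetical helper: which Spark versions wrote such tables with the overlapped schema
// and partition tracking stored in the metastore, so they must be repaired before reading.
def needsMsckRepairForOverlappedTable(createdBySparkVersion: String): Boolean =
  createdBySparkVersion.startsWith("2.1") || createdBySparkVersion == "2.2.0"

assert(needsMsckRepairForOverlappedTable("2.2.0"))
assert(!needsMsckRepairForOverlappedTable("2.0.2"))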
