
Commit 4dce45a

gengliangwang authored and cloud-fan committed
[SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework
## What changes were proposed in this pull request? The file source has a schema validation feature, which validates 2 schemas: 1. the user-specified schema when reading. 2. the schema of input data when writing. If a file source doesn't support the schema, we can fail the query earlier. This PR is to implement the same feature in the `FileDataSourceV2` framework. Comparing to `FileFormat`, `FileDataSourceV2` has multiple layers. The API is added in two places: 1. Read path: the table schema is determined in `TableProvider.getTable`. The actual read schema can be a subset of the table schema. This PR proposes to validate the actual read schema in `FileScan`. 2. Write path: validate the actual output schema in `FileWriteBuilder`. ## How was this patch tested? Unit test Closes #23714 from gengliangwang/schemaValidationV2. Authored-by: Gengliang Wang <gengliang.wang@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 4cabab8 commit 4dce45a
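
Both new hooks boil down to walking the schema's fields and rejecting any field whose type the format cannot handle. Below is a minimal standalone sketch of that check in Scala; the `unsupportedFields` helper and the sample schema are illustrative only and not part of this commit.

import org.apache.spark.sql.types._

// Mirrors the per-field check that FileScan.toBatch (read path) and
// FileWriteBuilder.validateInputs (write path) now perform.
def unsupportedFields(
    schema: StructType,
    supportsDataType: DataType => Boolean): Seq[StructField] =
  schema.filterNot(field => supportsDataType(field.dataType))

val schema = StructType(
  StructField("id", IntegerType) ::
  StructField("span", CalendarIntervalType) :: Nil)

// An ORC-like source rejects interval columns; the framework would raise
// "ORC data source does not support calendarinterval data type." here.
unsupportedFields(schema, _ != CalendarIntervalType)
  .foreach(f => println(s"unsupported: ${f.name} -> ${f.dataType.catalogString}"))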

6 files changed: +167 -77 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala

Lines changed: 29 additions & 4 deletions
@@ -18,22 +18,39 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 abstract class FileScan(
     sparkSession: SparkSession,
-    fileIndex: PartitioningAwareFileIndex) extends Scan with Batch {
+    fileIndex: PartitioningAwareFileIndex,
+    readSchema: StructType) extends Scan with Batch {
   /**
    * Returns whether a file with `path` could be split or not.
    */
   def isSplitable(path: Path): Boolean = {
     false
   }
 
+  /**
+   * Returns whether this format supports the given [[DataType]] in write path.
+   * By default all data types are supported.
+   */
+  def supportsDataType(dataType: DataType): Boolean = true
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName(): String = "ORC"
+   * }}}
+   */
+  def formatName: String
+
   protected def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -57,5 +74,13 @@ abstract class FileScan(
     partitions.toArray
   }
 
-  override def toBatch: Batch = this
+  override def toBatch: Batch = {
+    readSchema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+        throw new AnalysisException(
+          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+      }
+    }
+    this
+  }
 }
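
Because the check runs on the actual read schema, a projection over supported columns still works, while a user-specified schema containing an unsupported type now fails during planning, before any file is read. A usage sketch, assuming a SparkSession named `spark`, the V2 ORC reader (i.e. "orc" not in the V1 reader fallback list), and a hypothetical path:

import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType}

// Read path: the user-specified schema contains an interval column, so
// FileScan.toBatch throws AnalysisException during planning.
val badSchema = StructType(StructField("span", CalendarIntervalType) :: Nil)
spark.read.schema(badSchema).format("orc").load("/tmp/orc-data").collect()
// org.apache.spark.sql.AnalysisException:
//   ORC data source does not support calendarinterval data type.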

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala

Lines changed: 23 additions & 1 deletion
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 abstract class FileWriteBuilder(options: DataSourceOptions)
@@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory
 
+  /**
+   * Returns whether this format supports the given [[DataType]] in write path.
+   * By default all data types are supported.
+   */
+  def supportsDataType(dataType: DataType): Boolean = true
+
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName(): String = "ORC"
+   * }}}
+   */
+  def formatName: String
+
   private def validateInputs(): Unit = {
     assert(schema != null, "Missing input data schema")
     assert(queryId != null, "Missing query ID")
    assert(mode != null, "Missing save mode")
     assert(options.paths().length == 1)
     DataSource.validateSchema(schema)
+    schema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+        throw new AnalysisException(
+          s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+      }
+    }
   }
 
   private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
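
On the write side the same per-field check fires in `validateInputs`, before the write job is set up. A sketch of the observable behaviour, assuming a SparkSession named `spark`, the V2 ORC writer in use, and a hypothetical output path:

// Write path: the output schema carries an interval column, so the V2 ORC
// writer's validateInputs rejects it during analysis.
spark.sql("select interval 1 days as span")
  .write.format("orc").mode("overwrite").save("/tmp/orc-out")
// org.apache.spark.sql.AnalysisException:
//   ORC data source does not support calendarinterval data type.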

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala

Lines changed: 18 additions & 1 deletion
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 
 class OrcDataSourceV2 extends FileDataSourceV2 {
 
@@ -42,3 +42,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
     OrcTable(tableName, sparkSession, options, Some(schema))
   }
 }
+
+object OrcDataSourceV2 {
+  def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+    case ArrayType(elementType, _) => supportsDataType(elementType)
+
+    case MapType(keyType, valueType, _) =>
+      supportsDataType(keyType) && supportsDataType(valueType)
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _ => false
+  }
+}
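
The ORC rule accepts any composition of atomic types and recurses through struct, array, map and UDT wrappers, so only the leaf types matter. A few REPL-style probes against the new object, with expected results shown in comments:

import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.types._

// Nested types are fine as long as every leaf is an atomic type.
OrcDataSourceV2.supportsDataType(
  MapType(StringType, ArrayType(StructType(StructField("x", IntegerType) :: Nil))))  // true

// Interval and null leaves are rejected, however deeply they are nested.
OrcDataSourceV2.supportsDataType(ArrayType(CalendarIntervalType))                    // false
OrcDataSourceV2.supportsDataType(StructType(StructField("n", NullType) :: Nil))      // false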

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala

Lines changed: 8 additions & 2 deletions
@@ -23,15 +23,15 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
 import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 case class OrcScan(
     sparkSession: SparkSession,
     hadoopConf: Configuration,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType) extends FileScan(sparkSession, fileIndex) {
+    readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
@@ -40,4 +40,10 @@ case class OrcScan(
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, fileIndex.partitionSchema, readSchema)
   }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    OrcDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "ORC"
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala

Lines changed: 6 additions & 0 deletions
@@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio
       }
     }
   }
+
+  override def supportsDataType(dataType: DataType): Boolean = {
+    OrcDataSourceV2.supportsDataType(dataType)
+  }
+
+  override def formatName: String = "ORC"
 }

sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala

Lines changed: 83 additions & 69 deletions
@@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
   test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
     withTempDir { dir =>
       val tempDir = new File(dir, "files").getCanonicalPath
-      // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
-      withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
-        // write path
-        Seq("csv", "json", "parquet", "orc").foreach { format =>
-          var msg = intercept[AnalysisException] {
-            sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.contains("Cannot save interval data type into external storage."))
-
-          msg = intercept[AnalysisException] {
-            spark.udf.register("testType", () => new IntervalData())
-            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
+      Seq(true, false).foreach { useV1 =>
+        val useV1List = if (useV1) {
+          "orc"
+        } else {
+          ""
         }
+        def errorMessage(format: String, isWrite: Boolean): String = {
+          if (isWrite && (useV1 || format != "orc")) {
+            "cannot save interval data type into external storage."
+          } else {
+            s"$format data source does not support calendarinterval data type."
+          }
+        }
+
+        withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+          // write path
+          Seq("csv", "json", "parquet", "orc").foreach { format =>
+            var msg = intercept[AnalysisException] {
+              sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
+          }
 
-        // read path
-        Seq("parquet", "csv").foreach { format =>
-          var msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
-
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support calendarinterval data type."))
+          // read path
+          Seq("parquet", "csv").foreach { format =>
+            var msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+          }
         }
       }
     }
   }
 
   test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
-    // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
-    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
-      SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
-      withTempDir { dir =>
-        val tempDir = new File(dir, "files").getCanonicalPath
-
-        Seq("parquet", "csv", "orc").foreach { format =>
-          // write path
-          var msg = intercept[AnalysisException] {
-            sql("select null").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          msg = intercept[AnalysisException] {
-            spark.udf.register("testType", () => new NullData())
-            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          // read path
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", NullType, true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
-
-          msg = intercept[AnalysisException] {
-            val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
-            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-            spark.read.schema(schema).format(format).load(tempDir).collect()
-          }.getMessage
-          assert(msg.toLowerCase(Locale.ROOT)
-            .contains(s"$format data source does not support null data type."))
+    Seq(true, false).foreach { useV1 =>
+      val useV1List = if (useV1) {
+        "orc"
+      } else {
+        ""
+      }
+      def errorMessage(format: String): String = {
+        s"$format data source does not support null data type."
+      }
+      withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List,
+        SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+        withTempDir { dir =>
+          val tempDir = new File(dir, "files").getCanonicalPath
+
+          Seq("parquet", "csv", "orc").foreach { format =>
+            // write path
+            var msg = intercept[AnalysisException] {
+              sql("select null").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            msg = intercept[AnalysisException] {
+              spark.udf.register("testType", () => new NullData())
+              sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            // read path
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", NullType, true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+
+            msg = intercept[AnalysisException] {
+              val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
+              spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+              spark.read.schema(schema).format(format).load(tempDir).collect()
+            }.getMessage
+            assert(msg.toLowerCase(Locale.ROOT)
+              .contains(errorMessage(format)))
+          }
         }
       }
     }