
[SPARK-26263][SQL] Validate partition values with user provided schema #23215

Closed · wants to merge 8 commits
2 changes: 2 additions & 0 deletions docs/sql-migration-guide-upgrade.md
@@ -29,6 +29,8 @@ displayTitle: Spark SQL Upgrading Guide

- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with a last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.

- In Spark version 2.4 and earlier, a partition column value is converted to null if it cannot be cast to the corresponding user-provided schema. Since 3.0, the partition column value is validated against the user-provided schema, and an exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.

- In Spark version 2.4 and earlier, the `SET` command works without any warnings even if the specified key is for a `SparkConf` entry; it has no effect because the command does not update `SparkConf`, but the behavior might confuse users. Since 3.0, the command fails if a `SparkConf` key is used. You can disable such a check by setting `spark.sql.legacy.execution.setCommandRejectsSparkConfs` to `false`.

- Spark applications which are built with Spark version 2.4 and prior, and call methods of `UserDefinedFunction`, need to be re-compiled with Spark 3.0, as they are not binary compatible with Spark 3.0.
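To make the migration-guide entry above concrete, here is a minimal sketch of the behavior change; the directory layout, schema, and `spark` session are hypothetical and not part of this patch:

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical layout: /tmp/tbl/a=foo/part-00000.parquet, where the user-provided
// schema declares the partition column `a` as an integer. "foo" cannot be cast to INT.
val schema = StructType(Seq(StructField("a", IntegerType), StructField("v", StringType)))

// Spark 3.0 default: validation is on, so resolving the partition spec fails.
spark.conf.set("spark.sql.sources.validatePartitionColumns", "true")
// spark.read.schema(schema).parquet("/tmp/tbl")
// => RuntimeException: Failed to cast value `foo` to `IntegerType` for partition column `a`

// Setting the flag to false restores the 2.4 behavior: the unparsable value becomes null.
spark.conf.set("spark.sql.sources.validatePartitionColumns", "false")
// spark.read.schema(schema).parquet("/tmp/tbl").select("a")   // a is null for that partition
```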
@@ -1396,6 +1396,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val VALIDATE_PARTITION_COLUMNS =
buildConf("spark.sql.sources.validatePartitionColumns")
.internal()
.doc("When this option is set to true, partition column values will be validated with " +
"user-specified schema. If the validation fails, a runtime exception is thrown." +
"When this option is set to false, the partition column value will be converted to null " +
"if it can not be casted to corresponding user-specified schema.")
.booleanConf
.createWithDefault(true)

val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -2014,6 +2024,8 @@ class SQLConf extends Serializable with Logging {
def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)

def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS)

def partitionOverwriteMode: PartitionOverwriteMode.Value =
PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))

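Since `spark.sql.sources.validatePartitionColumns` is a regular boolean session conf, it can also be toggled at runtime. A small hedged sketch (the `spark` session is assumed; the conf is marked internal, so it is not listed in the public docs):

```scala
// Flip the flag for the current session via SQL and read it back through the conf API.
spark.sql("SET spark.sql.sources.validatePartitionColumns=false")
assert(spark.conf.get("spark.sql.sources.validatePartitionColumns") == "false")

// The file-source code path reads it through SQLConf.validatePartitionColumns (see below).
spark.conf.set("spark.sql.sources.validatePartitionColumns", "true")
```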
@@ -127,13 +127,13 @@ abstract class PartitioningAwareFileIndex(
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)

val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis
PartitioningUtils.parsePartitions(
leafDirs,
typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
basePaths = basePaths,
userSpecifiedSchema = userSpecifiedSchema,
caseSensitive = caseSensitive,
caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis,
validatePartitionColumns = sparkSession.sqlContext.conf.validatePartitionColumns,
timeZoneId = timeZoneId)
}

@@ -26,12 +26,13 @@ import scala.util.Try

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

@@ -96,9 +97,10 @@ object PartitioningUtils {
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
validatePartitionColumns: Boolean,
timeZoneId: String): PartitionSpec = {
parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema,
caseSensitive, DateTimeUtils.getTimeZone(timeZoneId))
parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema, caseSensitive,
validatePartitionColumns, DateTimeUtils.getTimeZone(timeZoneId))
}

private[datasources] def parsePartitions(
@@ -107,6 +109,7 @@
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
validatePartitionColumns: Boolean,
timeZone: TimeZone): PartitionSpec = {
val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
@@ -121,7 +124,8 @@

// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone)
parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
validatePartitionColumns, timeZone)
}.unzip

// We create pairs of (path -> path's partition value) here
@@ -203,6 +207,7 @@ object PartitioningUtils {
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedDataTypes: Map[String, DataType],
validatePartitionColumns: Boolean,
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
@@ -224,7 +229,8 @@
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone)
parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
validatePartitionColumns, timeZone)
maybeColumn.foreach(columns += _)

// Now, we determine if we should stop.
@@ -258,6 +264,7 @@
columnSpec: String,
typeInference: Boolean,
userSpecifiedDataTypes: Map[String, DataType],
validatePartitionColumns: Boolean,
timeZone: TimeZone): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
@@ -272,10 +279,15 @@
val literal = if (userSpecifiedDataTypes.contains(columnName)) {
// SPARK-26188: if user provides corresponding column schema, get the column value without
// inference, and then cast it as user specified data type.
val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone)
val castedValue =
Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval()
Literal.create(castedValue, userSpecifiedDataTypes(columnName))
val dataType = userSpecifiedDataTypes(columnName)
val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone)
val columnValue = columnValueLiteral.eval()
val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
if (validatePartitionColumns && columnValue != null && castedValue == null) {
throw new RuntimeException(s"Failed to cast value `$columnValue` to `$dataType` " +
s"for partition column `$columnName`")
}
Literal.create(castedValue, dataType)
} else {
inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
}
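The core of the change is the null-after-cast check added to `parsePartitionColumn`: the raw partition string is parsed without inference, cast to the user-specified type, and a non-null raw value that becomes null after the cast is rejected. A stripped-down sketch of that rule using Catalyst's `Cast` and `Literal` directly (internal classes; a simplified mirror of the patch, not its exact code path):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical helper mirroring the validation rule added above.
def checkPartitionValue(
    columnName: String,
    rawValue: String,
    dataType: DataType,
    validatePartitionColumns: Boolean,
    timeZoneId: String = "UTC"): Any = {
  val rawLiteral = Literal(rawValue)
  // Cast the raw (string) value to the user-specified type; an invalid value yields null.
  val castedValue = Cast(rawLiteral, dataType, Option(timeZoneId)).eval()
  if (validatePartitionColumns && rawValue != null && castedValue == null) {
    throw new RuntimeException(
      s"Failed to cast value `$rawValue` to `$dataType` for partition column `$columnName`")
  }
  castedValue
}

// checkPartitionValue("a", "10", IntegerType, validatePartitionColumns = true)   // 10
// checkPartitionValue("a", "foo", IntegerType, validatePartitionColumns = true)  // throws
// checkPartitionValue("a", "foo", IntegerType, validatePartitionColumns = false) // null
```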
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

class FileIndexSuite extends SharedSQLContext {
@@ -95,6 +95,31 @@
}
}

test("SPARK-26263: Throw exception when partition value can't be casted to user-specified type") {
withTempDir { dir =>
val partitionDirectory = new File(dir, "a=foo")
partitionDirectory.mkdir()
val file = new File(partitionDirectory, "text.txt")
stringToFile(file, "text")
val path = new Path(dir.getCanonicalPath)
val schema = StructType(Seq(StructField("a", IntegerType, false)))
withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "true") {
val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
val msg = intercept[RuntimeException] {
fileIndex.partitionSpec()
}.getMessage
assert(msg == "Failed to cast value `foo` to `IntegerType` for partition column `a`")
}

withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "false") {
val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
val partitionValues = fileIndex.partitionSpec().partitions.map(_.values)
assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 &&
partitionValues(0).isNullAt(0))
}
}
}

test("InMemoryFileIndex: input paths are converted to qualified paths") {
withTempDir { dir =>
val file = new File(dir, "text.txt")
@@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
"hdfs://host:9000/path/a=10.5/b=hello")

var exception = intercept[AssertionError] {
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId)
parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, true, timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))

@@ -117,6 +117,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/")),
None,
true,
true,
timeZoneId)

// Valid
@@ -132,6 +133,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/something=true/table")),
None,
true,
true,
timeZoneId)

// Valid
@@ -147,6 +149,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/table=true")),
None,
true,
true,
timeZoneId)

// Invalid
@@ -162,6 +165,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/")),
None,
true,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -184,20 +188,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/tmp/tables/")),
None,
true,
true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
}

test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1
val actual = parsePartition(new Path(path), true, Set.empty[Path],
Map.empty, true, timeZone)._1
assert(expected === actual)
}

def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)
parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone)
}.getMessage

assert(message.contains(expected))
@@ -242,6 +248,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
typeInference = true,
basePaths = Set(new Path("file://path/a=10")),
Map.empty,
true,
timeZone = timeZone)._1

assert(partitionSpec1.isEmpty)
@@ -252,6 +259,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
typeInference = true,
basePaths = Set(new Path("file://path")),
Map.empty,
true,
timeZone = timeZone)._1

assert(partitionSpec2 ==
@@ -272,6 +280,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
rootPaths,
None,
true,
true,
timeZoneId)
assert(actualSpec.partitionColumns === spec.partitionColumns)
assert(actualSpec.partitions.length === spec.partitions.length)
@@ -384,7 +393,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("parse partitions with type inference disabled") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
val actualSpec =
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId)
parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None,
true, true, timeZoneId)
assert(actualSpec === spec)
}
