Skip to content

Commit

Permalink
[Spark] Fix dependent constraints/generated columns checker for type …
Browse files Browse the repository at this point in the history
…widening (#3912)

<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
The current checker of dependent expressions doesn't validate changes
for array and map types. For example, usage of type widening could lead
to constraints breaks:
```
scala> sql("CREATE TABLE table (a array<byte>) USING DELTA")
scala> sql("INSERT INTO table VALUES (array(1, -2, 3))")
scala> sql("SELECT hash(a[1]) FROM table").show()
+-----------+
| hash(a[1])|
+-----------+
|-1160545675|
+-----------+

scala> sql("ALTER TABLE table ADD CONSTRAINT ch1 CHECK (hash(a[1]) = -1160545675)")
scala> sql("ALTER TABLE table SET TBLPROPERTIES('delta.enableTypeWidening' = true)")
scala> sql("ALTER TABLE table CHANGE COLUMN a.element TYPE BIGINT")
scala> sql("SELECT hash(a[1]) FROM table").show()
+----------+
|hash(a[1])|
+----------+
|-981642528|
+----------+

scala> sql("INSERT INTO table VALUES (array(1, -2, 3))")
24/11/15 12:53:23 ERROR Utils: Aborting task
com.databricks.sql.transaction.tahoe.schema.DeltaInvariantViolationException: [DELTA_VIOLATE_CONSTRAINT_WITH_VALUES] CHECK constraint ch1 (hash(a[1]) = -1160545675) violated by row with values:
```
The proposed algorithm is more strict and regards maps, arrays and
structs during constraints/generated columns dependencies.
<!--
- Describe what this PR changes.
- Describe why we need the change.
 
If this PR resolves an issue be sure to include "Resolves #XXX" to
correctly link and close the issue upon merge.
-->

## How was this patch tested?
Added new tests for constraints and generated columns used with type
widening feature.
<!--
If tests were added, say they were added here. Please make sure to test
the changes thoroughly including negative and positive cases if
possible.
If the changes were tested in any way other than unit tests, please
clarify how you tested step by step (ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future).
If the changes were not tested, please explain why.
-->

## Does this PR introduce _any_ user-facing changes?
Due to strictness of the algorithm new potential dangerous type changes
will be prohibited. An exception will be thrown in the example above.
But such changes are called in the schema evolution feature mostly that
was introduced recently, so it should not affect many users.
<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
  • Loading branch information
Alexvsalexvsalex authored Dec 3, 2024
1 parent 81f27b3 commit 2937bc8
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,22 @@ trait AlterDeltaTableCommand extends DeltaCommand {
protected def checkDependentExpressions(
sparkSession: SparkSession,
columnParts: Seq[String],
newMetadata: actions.Metadata,
oldMetadata: actions.Metadata,
protocol: Protocol): Unit = {
if (!sparkSession.sessionState.conf.getConf(
DeltaSQLConf.DELTA_ALTER_TABLE_CHANGE_COLUMN_CHECK_EXPRESSIONS)) {
return
}
// check if the column to change is referenced by check constraints
val dependentConstraints =
Constraints.findDependentConstraints(sparkSession, columnParts, newMetadata)
Constraints.findDependentConstraints(sparkSession, columnParts, oldMetadata)
if (dependentConstraints.nonEmpty) {
throw DeltaErrors.foundViolatingConstraintsForColumnChange(
UnresolvedAttribute(columnParts).name, dependentConstraints)
}
// check if the column to change is referenced by any generated columns
val dependentGenCols = SchemaUtils.findDependentGeneratedColumns(
sparkSession, columnParts, protocol, newMetadata.schema)
sparkSession, columnParts, protocol, oldMetadata.schema)
if (dependentGenCols.nonEmpty) {
throw DeltaErrors.foundViolatingGeneratedColumnsForColumnChange(
UnresolvedAttribute(columnParts).name, dependentGenCols)
Expand Down Expand Up @@ -768,7 +768,7 @@ case class AlterTableDropColumnsDeltaCommand(
configuration = newConfiguration
)
columnsToDrop.foreach { columnParts =>
checkDependentExpressions(sparkSession, columnParts, newMetadata, txn.protocol)
checkDependentExpressions(sparkSession, columnParts, metadata, txn.protocol)
}

txn.updateMetadata(newMetadata)
Expand Down Expand Up @@ -927,7 +927,7 @@ case class AlterTableChangeColumnDeltaCommand(
if (newColumn.name != columnName) {
// need to validate the changes if the column is renamed
checkDependentExpressions(
sparkSession, columnPath :+ columnName, newMetadata, txn.protocol)
sparkSession, columnPath :+ columnName, metadata, txn.protocol)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ object Constraints {
metadata.configuration.filter {
case (key, constraint) if key.toLowerCase(Locale.ROOT).startsWith("delta.constraints.") =>
SchemaUtils.containsDependentExpression(
sparkSession, columnName, constraint, sparkSession.sessionState.conf.resolver)
sparkSession,
columnName,
constraint,
metadata.schema,
sparkSession.sessionState.conf.resolver)
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.FileSourceGeneratedMetadataStructField
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

/**
* A trait that writers into Delta can extend to update the schema and/or partitioning of the table.
Expand Down Expand Up @@ -309,19 +309,19 @@ object ImplicitMetadataOperation {
currentDt: DataType,
updateDt: DataType): Unit = (currentDt, updateDt) match {
// we explicitly ignore the check for `StructType` here.
case (StructType(_), StructType(_)) =>

// FIXME: we intentionally incorporate the pattern match for `ArrayType` and `MapType`
// here mainly due to the field paths for maps/arrays in constraints/generated columns
// are *NOT* consistent with regular field paths,
// e.g., `hash(a.arr[0].x)` vs. `hash(a.element.x)`.
// this makes it hard to recurse into maps/arrays and check for the corresponding
// fields - thus we can not actually block the operation even if the updated field
// is being referenced by any CHECK constraints or generated columns.
case (from, to) =>
case (_: StructType, _: StructType) =>
case (current: ArrayType, update: ArrayType) =>
checkConstraintsOrGeneratedColumnsOnStructField(
spark, path :+ "element", protocol, metadata, current.elementType, update.elementType)
case (current: MapType, update: MapType) =>
checkConstraintsOrGeneratedColumnsOnStructField(
spark, path :+ "key", protocol, metadata, current.keyType, update.keyType)
checkConstraintsOrGeneratedColumnsOnStructField(
spark, path :+ "value", protocol, metadata, current.valueType, update.valueType)
case (_, _) =>
if (currentDt != updateDt) {
checkDependentConstraints(spark, path, metadata, from, to)
checkDependentGeneratedColumns(spark, path, protocol, metadata, from, to)
checkDependentConstraints(spark, path, metadata, currentDt, updateDt)
checkDependentGeneratedColumns(spark, path, protocol, metadata, currentDt, updateDt)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ import org.apache.spark.internal.MDC
import org.apache.spark.sql._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetArrayItem, GetArrayStructFields, GetMapValue, GetStructField}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1269,20 +1270,58 @@ def normalizeColumnNamesInDataType(
// identifier with back-ticks.
def quoteIdentifier(part: String): String = s"`${part.replace("`", "``")}`"

private def analyzeExpression(
spark: SparkSession,
expr: Expression,
schema: StructType): Expression = {
// Workaround for `exp` analyze
val relation = LocalRelation(schema)
val relationWithExp = Project(Seq(Alias(expr, "validate_column")()), relation)
val analyzedPlan = spark.sessionState.analyzer.execute(relationWithExp)
analyzedPlan.collectFirst {
case Project(Seq(a: Alias), _: LocalRelation) => a.child
}.get
}

/**
* Will a column change, e.g., rename, need to be populated to the expression. This is true when
* the column to change itself or any of its descendent column is referenced by expression.
* Collects all attribute references in the given expression tree as a list of paths.
* In particular, generates paths for nested fields accessed using extraction expressions.
* For example:
* - a, length(a) -> true
* - b, (b.c + 1) -> true, because renaming b1 will need to change the expr to (b1.c + 1).
* - b.c, (cast b as string) -> false, because you can change b.c to b.c1 without affecting b.
* - GetStructField(AttributeReference("struct"), "a") -> ["struct.a"]
* - Size(AttributeReference("array")) -> ["array"]
*/
def containsDependentExpression(
spark: SparkSession,
private def collectUsedColumns(expression: Expression): Seq[Seq[String]] = {
val result = new collection.mutable.ArrayBuffer[Seq[String]]()

// Firstly, try to get referenced column for a child's expression.
// If it exists then we try to extend it by current expression.
// In case if we cannot extend one, we save the received column path (it's as long as possible).
def traverseAllPaths(exp: Expression): Option[Seq[String]] = exp match {
case GetStructField(child, _, Some(name)) => traverseAllPaths(child).map(_ :+ name)
case GetMapValue(child, key) =>
traverseAllPaths(key).foreach(result += _)
traverseAllPaths(child).map { childPath =>
result += childPath :+ "key"
childPath :+ "value"
}
case arrayExtract: GetArrayItem => traverseAllPaths(arrayExtract.child).map(_ :+ "element")
case arrayExtract: GetArrayStructFields =>
traverseAllPaths(arrayExtract.child).map(_ :+ "element" :+ arrayExtract.field.name)
case refCol: AttributeReference => Some(Seq(refCol.name))
case _ =>
exp.children.foreach(child => traverseAllPaths(child).foreach(result += _))
None
}

traverseAllPaths(expression).foreach(result += _)

result.toSeq
}

private def fallbackContainsDependentExpression(
expression: Expression,
columnToChange: Seq[String],
exprString: String,
resolver: Resolver): Boolean = {
val expression = spark.sessionState.sqlParser.parseExpression(exprString)
expression.foreach {
case refCol: UnresolvedAttribute =>
// columnToChange is the referenced column or its prefix
Expand All @@ -1294,6 +1333,51 @@ def normalizeColumnNamesInDataType(
false
}

/**
* Will a column change, e.g., rename, need to be populated to the expression. This is true when
* the column to change itself or any of its descendent column is referenced by expression.
* For example:
* - a, length(a) -> true
* - b, (b.c + 1) -> true, because renaming b1 will need to change the expr to (b1.c + 1).
* - b.c, (cast b as string) -> true, because change b.c to b.c1 affects (b as string) result.
*/
def containsDependentExpression(
spark: SparkSession,
columnToChange: Seq[String],
exprString: String,
schema: StructType,
resolver: Resolver): Boolean = {
val expression = spark.sessionState.sqlParser.parseExpression(exprString)
if (spark.sessionState.conf.getConf(
DeltaSQLConf.DELTA_CHANGE_COLUMN_CHECK_DEPENDENT_EXPRESSIONS_USE_V2)) {
try {
val analyzedExpr = analyzeExpression(spark, expression, schema)
val exprColumns = collectUsedColumns(analyzedExpr)
exprColumns.exists { exprColumn =>
// Changed column violates expression's column only when:
// 1) the changed column is a prefix of the referenced column,
// for example changing type of `col` affects `hash(col[0]) == 0`;
// 2) or the referenced column is a prefix of the changed column,
// for example changing type of `col.element` affects `concat_ws('', col) == 'abc'`;
// 3) or they are equal.
exprColumn.zip(columnToChange).forall {
case (exprFieldName, changedFieldName) => resolver(exprFieldName, changedFieldName)
}
}
} catch {
case NonFatal(e) =>
deltaAssert(
check = false,
name = "containsDependentExpression.checkV2Error",
msg = "Exception during dependent expression V2 checking: " + e.getMessage
)
fallbackContainsDependentExpression(expression, columnToChange, resolver)
}
} else {
fallbackContainsDependentExpression(expression, columnToChange, resolver)
}
}

/**
* Find the unsupported data type in a table schema. Return all columns that are using unsupported
* data types. For example,
Expand Down Expand Up @@ -1402,7 +1486,7 @@ def normalizeColumnNamesInDataType(
SchemaMergingUtils.transformColumns(schema) { (_, field, _) =>
GeneratedColumn.getGenerationExpressionStr(field.metadata).foreach { exprStr =>
val needsToChangeExpr = SchemaUtils.containsDependentExpression(
sparkSession, targetColumn, exprStr, sparkSession.sessionState.conf.resolver)
sparkSession, targetColumn, exprStr, schema, sparkSession.sessionState.conf.resolver)
if (needsToChangeExpr) dependentGenCols += field.name -> exprStr
}
field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,21 @@ trait DeltaSQLConfBase {
.booleanConf
.createWithDefault(true)

val DELTA_CHANGE_COLUMN_CHECK_DEPENDENT_EXPRESSIONS_USE_V2 =
buildConf("changeColumn.checkDependentExpressionsUseV2")
.internal()
.doc(
"""
|More accurate implementation of checker for altering/renaming/dropping columns
|that might be referenced by constraints or generation rules.
|It respects nested arrays and maps, unlike the V1 checker.
|
|This is a safety switch - we should only turn this off when there is an issue with
|expression checking logic that prevents a valid column change from going through.
|""".stripMargin)
.booleanConf
.createWithDefault(true)

val DELTA_ALTER_TABLE_DROP_COLUMN_ENABLED =
buildConf("alterTable.dropColumn.enabled")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.scalatest.GivenWhenThen

import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -3048,6 +3048,100 @@ class SchemaUtilsSuite extends QueryTest
assert(udts.map(_.getClass.getName).toSet == Set(classOf[PointUDT].getName))
}


test("check if column affects given dependent expressions") {
val schema = StructType(Seq(
StructField("cArray", ArrayType(StringType)),
StructField("cStruct", StructType(Seq(
StructField("cMap", MapType(IntegerType, ArrayType(BooleanType))),
StructField("cMapWithComplexKey", MapType(StructType(Seq(
StructField("a", ArrayType(StringType)),
StructField("b", BooleanType)
)), IntegerType))
)))
))
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cArray"),
exprString = "cast(cStruct.cMap as string) == '{}'",
schema,
caseInsensitiveResolution) === false
)
// Extracting value from map uses key type as well.
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cStruct", "cMap", "key"),
exprString = "cStruct.cMap['random_key'] == 'string'",
schema,
caseInsensitiveResolution) === true
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cstruct"),
exprString = "size(cStruct.cMap) == 0",
schema,
caseSensitiveResolution) === false
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cStruct", "cMap", "key"),
exprString = "size(cArray) == 1",
schema,
caseInsensitiveResolution) === false
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cStruct", "cMap", "key"),
exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
schema,
caseInsensitiveResolution) === false
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cArray", "element"),
exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
schema,
caseInsensitiveResolution) === true
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cStruct", "cMapWithComplexKey", "key", "b"),
exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
schema,
caseInsensitiveResolution) === true
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("cArray", "element"),
exprString = "concat_ws('', cArray) == 'string'",
schema,
caseInsensitiveResolution) === true
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("CARRAY"),
exprString = "cArray[0] > 'a'",
schema,
caseInsensitiveResolution) === true
)
assert(
SchemaUtils.containsDependentExpression(
spark,
columnToChange = Seq("CARRAY", "element"),
exprString = "cArray[0] > 'a'",
schema,
caseSensitiveResolution) === false
)
}
}

object UnsupportedDataType extends DataType {
Expand Down
Loading

0 comments on commit 2937bc8

Please sign in to comment.