Use storeAssignmentPolicy for casts in Delta DML commands #1938

Closed
8 changes: 8 additions & 0 deletions spark/src/main/resources/error/delta-error-classes.json
@@ -272,6 +272,14 @@
],
"sqlState" : "0A000"
},
"DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" : {
"message" : [
"Failed to write a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow.",
"Use `try_cast` on the input value to tolerate overflow and return NULL instead.",
"If necessary, set <storeAssignmentPolicyFlag> to \"LEGACY\" to bypass this error or set <updateAndMergeCastingFollowsAnsiEnabledFlag> to true to revert to the old behaviour and follow <ansiEnabledFlag> in UPDATE and MERGE."
],
"sqlState" : "22003"
},
"DELTA_CDC_NOT_ALLOWED_IN_THIS_VERSION" : {
"message" : [
"Configuration delta.enableChangeDataFeed cannot be set. Change data feed from Delta is not yet available."
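To illustrate the new error class, a hedged repro sketch (table, column, and values are illustrative, and the exact message rendering may differ):

    // Assumes Spark 3.x with its default spark.sql.storeAssignmentPolicy=ANSI.
    spark.sql("CREATE TABLE tbl (id INT, b BYTE) USING delta")
    spark.sql("INSERT INTO tbl VALUES (1, 1)")
    // 1024 does not fit into BYTE, so instead of silently wrapping around the
    // UPDATE now fails with DELTA_CAST_OVERFLOW_IN_TABLE_WRITE:
    spark.sql("UPDATE tbl SET b = 1024")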
@@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}

@@ -118,7 +119,8 @@ trait DocsPath {
*/
trait DeltaErrorsBase
extends DocsPath
with DeltaLogging {
with DeltaLogging
with QueryErrorsBase {

def baseDocsPath(spark: SparkSession): String = baseDocsPath(spark.sparkContext.getConf)

@@ -614,6 +616,22 @@ trait DeltaErrorsBase
)
}

def castingCauseOverflowErrorInTableWrite(
from: DataType,
to: DataType,
columnName: String): ArithmeticException = {
new DeltaArithmeticException(
errorClass = "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE",
messageParameters = Map(
"sourceType" -> toSQLType(from),
"targetType" -> toSQLType(to),
"columnName" -> toSQLId(columnName),
"storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
"updateAndMergeCastingFollowsAnsiEnabledFlag" ->
DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
"ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key))
}

def notADeltaTable(table: String): Throwable = {
new DeltaAnalysisException(errorClass = "DELTA_NOT_A_DELTA_TABLE",
messageParameters = Array(table))
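A minimal sketch of how this helper fills the placeholders declared in delta-error-classes.json (types and column name are illustrative):

    import org.apache.spark.sql.types.{ByteType, IntegerType}

    val e = DeltaErrors.castingCauseOverflowErrorInTableWrite(
      from = IntegerType, to = ByteType, columnName = "b")
    // toSQLType/toSQLId quote the parameters, so the message reads roughly:
    //   Failed to write a value of "INT" type into the "TINYINT" type column `b`
    //   due to an overflow. ...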
@@ -81,3 +81,11 @@ class DeltaParseException(
ParserUtils.position(ctx.getStop)
) with DeltaThrowable

class DeltaArithmeticException(
errorClass: String,
messageParameters: Map[String, String]) extends ArithmeticException with DeltaThrowable {
override def getErrorClass: String = errorClass

override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
}
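
Because DeltaArithmeticException still extends ArithmeticException, existing catch blocks keep working, while new callers can match on the structured error class instead of parsing message text. A hedged sketch (in practice Spark may wrap the task failure in a higher-level exception):

    try {
      spark.sql("UPDATE tbl SET b = 1024")
    } catch {
      case e: DeltaArithmeticException
          if e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" =>
        // getMessageParameters exposes the structured fields of the error
        println(e.getMessageParameters.get("columnName"))
    }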

@@ -175,7 +175,8 @@ case class PreprocessTableMerge(override val conf: SQLConf)
castIfNeeded(
a.expr,
targetAttrib.dataType,
allowStructEvolution = migrateSchema),
allowStructEvolution = migrateSchema,
targetAttrib.name),
targetColNameResolved = true)
}.getOrElse {
// If a target table column was not found in the INSERT columns and expressions,
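For context, an illustrative MERGE (all table and column names hypothetical) that exercises this path: the INSERT action casts s.wide_value into the target column b, and threading targetAttrib.name into castIfNeeded means an overflow here now reports which column was being written:

    spark.sql("""
      MERGE INTO target t
      USING source s
      ON t.id = s.id
      WHEN NOT MATCHED THEN INSERT (id, b) VALUES (s.id, s.wide_value)
    """)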
@@ -21,17 +21,20 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.AnalysisHelper

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{CastSupport, Resolver}
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
* Trait with helper functions to generate expressions to update target columns, even if they are
* nested fields.
*/
trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with AnalysisHelper {
trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper {
/**
* Specifies an operation that updates a target column with the given expression.
* The target column may or may not be a nested field and it is specified as a full quoted name
@@ -49,11 +52,13 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
* struct casting will throw an error if there's any mismatch between
* column names. For example, (b, c, a) -> (a, b, c) is always a valid
* cast, but (a, b) -> (a, b, c) is valid only with this flag set.
* @param columnName The name of the column written to. It is used for the error message.
*/
protected def castIfNeeded(
fromExpression: Expression,
dataType: DataType,
allowStructEvolution: Boolean = false): Expression = {
allowStructEvolution: Boolean,
columnName: String): Expression = {

fromExpression match {
// Need to deal with NullType here, as some types cannot be casted from NullType, e.g.,
@@ -69,36 +74,46 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
// If fromExpression is an array function returning an array, cast the
// underlying array first and then perform the function on the transformed array.
case ArrayUnion(leftExpression, rightExpression) =>
val castedLeft = castIfNeeded(leftExpression, dataType, allowStructEvolution)
val castedRight = castIfNeeded(rightExpression, dataType, allowStructEvolution)
val castedLeft =
castIfNeeded(leftExpression, dataType, allowStructEvolution, columnName)
val castedRight =
castIfNeeded(rightExpression, dataType, allowStructEvolution, columnName)
ArrayUnion(castedLeft, castedRight)

case ArrayIntersect(leftExpression, rightExpression) =>
val castedLeft = castIfNeeded(leftExpression, dataType, allowStructEvolution)
val castedRight = castIfNeeded(rightExpression, dataType, allowStructEvolution)
val castedLeft =
castIfNeeded(leftExpression, dataType, allowStructEvolution, columnName)
val castedRight =
castIfNeeded(rightExpression, dataType, allowStructEvolution, columnName)
ArrayIntersect(castedLeft, castedRight)

case ArrayExcept(leftExpression, rightExpression) =>
val castedLeft = castIfNeeded(leftExpression, dataType, allowStructEvolution)
val castedRight = castIfNeeded(rightExpression, dataType, allowStructEvolution)
val castedLeft =
castIfNeeded(leftExpression, dataType, allowStructEvolution, columnName)
val castedRight =
castIfNeeded(rightExpression, dataType, allowStructEvolution, columnName)
ArrayExcept(castedLeft, castedRight)

case ArrayRemove(leftExpression, rightExpression) =>
val castedLeft = castIfNeeded(leftExpression, dataType, allowStructEvolution)
val castedLeft =
castIfNeeded(leftExpression, dataType, allowStructEvolution, columnName)
// ArrayRemove removes all elements that are equal to the given element from the array.
// In this case, the element to be removed also needs to be casted into the target
// array's element type.
val castedRight = castIfNeeded(rightExpression, toEt, allowStructEvolution)
val castedRight =
castIfNeeded(rightExpression, toEt, allowStructEvolution, columnName)
ArrayRemove(castedLeft, castedRight)

case ArrayDistinct(expression) =>
val castedExpr = castIfNeeded(expression, dataType, allowStructEvolution)
val castedExpr =
castIfNeeded(expression, dataType, allowStructEvolution, columnName)
ArrayDistinct(castedExpr)

case _ =>
// generate a lambda function to cast each array item into the target element type.
val structConverter: (Expression, Expression) => Expression = (_, i) =>
castIfNeeded(GetArrayItem(fromExpression, i), toEt, allowStructEvolution)
castIfNeeded(
GetArrayItem(fromExpression, i), toEt, allowStructEvolution, columnName)
val transformLambdaFunc = {
val elementVar = NamedLambdaVariable("elementVar", toEt, toContainsNull)
val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
Expand All @@ -111,7 +126,8 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
// containsNull as true to avoid casting failures.
cast(
ArrayTransform(fromExpression, transformLambdaFunc),
ArrayType(toEt, containsNull = true)
ArrayType(toEt, containsNull = true),
columnName
)
}
case (from: MapType, to: MapType) if !Cast.canCast(from, to) =>
@@ -128,17 +144,17 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
if (from.keyType != to.keyType) {
transformedKeysAndValues =
TransformKeys(transformedKeysAndValues, createMapConverter {
(key, _) => castIfNeeded(key, to.keyType, allowStructEvolution)
(key, _) => castIfNeeded(key, to.keyType, allowStructEvolution, columnName)
})
}

if (from.valueType != to.valueType) {
transformedKeysAndValues =
TransformValues(transformedKeysAndValues, createMapConverter {
(_, value) => castIfNeeded(value, to.valueType, allowStructEvolution)
(_, value) => castIfNeeded(value, to.valueType, allowStructEvolution, columnName)
})
}
cast(transformedKeysAndValues, to)
cast(transformedKeysAndValues, to, columnName)
case (from: StructType, to: StructType)
if !DataType.equalsIgnoreCaseAndNullability(from, to) && resolveStructsByName =>
// All from fields must be present in the final schema, or we'll silently lose data.
@@ -167,12 +183,13 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
}
Literal(null)
}
Seq(fieldNameLit, castIfNeeded(extractedField, field.dataType, allowStructEvolution))
Seq(fieldNameLit,
castIfNeeded(extractedField, field.dataType, allowStructEvolution, field.name))
})

cast(nameMappedStruct, to.asNullable)
cast(nameMappedStruct, to.asNullable, columnName)

case (from, to) if (from != to) => cast(fromExpression, dataType)
case (from, to) if (from != to) => cast(fromExpression, dataType, columnName)
case _ => fromExpression
}
}
@@ -267,8 +284,11 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
prefixMatchedOps.map(op => (pathPrefix ++ op.targetColNameParts).mkString(".")))
}
// For an exact match, return the updateExpr from the update operation.
Some(
castIfNeeded(fullyMatchedOp.get.updateExpr, targetCol.dataType, allowStructEvolution))
Some(castIfNeeded(
fullyMatchedOp.get.updateExpr,
targetCol.dataType,
allowStructEvolution,
targetCol.name))
} else {
// So there are prefix-matched update operations, but none of them is a full match. Then
// that means targetCol is a complex data type, so we recursively pass along the update
@@ -387,4 +407,110 @@ trait UpdateExpressionsSupport extends CastSupport with SQLConfHelper with Analy
}
}
}

/**
* Replaces 'CastSupport.cast'. Selects a cast based on 'spark.sql.storeAssignmentPolicy' if
* 'spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag' is false, and based on
* 'spark.sql.ansi.enabled' otherwise.
*/
private def cast(child: Expression, dataType: DataType, columnName: String): Expression = {
if (conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) {
return Cast(child, dataType, Option(conf.sessionLocalTimeZone))
}

conf.storeAssignmentPolicy match {
case SQLConf.StoreAssignmentPolicy.LEGACY =>
Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = false)
case SQLConf.StoreAssignmentPolicy.ANSI =>
val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true)
if (canCauseCastOverflow(cast)) {
CheckOverflowInTableWrite(cast, columnName)
} else {
cast
}
case SQLConf.StoreAssignmentPolicy.STRICT =>
UpCast(child, dataType)
}
}

private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match {
case _: IntegralType | _: DecimalType => true
case a: ArrayType => containsIntegralOrDecimalType(a.elementType)
case m: MapType =>
containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType)
case s: StructType =>
s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType))
case _ => false
}

private def canCauseCastOverflow(cast: Cast): Boolean = {
containsIntegralOrDecimalType(cast.dataType) &&
!Cast.canUpCast(cast.child.dataType, cast.dataType)
}
}

case class CheckOverflowInTableWrite(child: Expression, columnName: String)

Contributor:
Seems to be almost the same as CheckOverflowInTableInsert in Spark, could we inherit it?

Contributor Author:
I was indeed looking into it, but there are some problems with that:

  • This class is a case class and CheckOverflowInTableInsert is one as well, so Scala does not like that (a case class cannot extend another case class)
  • I cannot refactor CheckOverflowInTableInsert because it is part of the OSS Spark imports. This also prevents me from making methods like getCast protected or the error-throwing function configurable
  • Even if I could override methods, with doGenCodeWithBetterErrorMsg and eval I would have to override the main ones, so most of the code would not be shared (this is more of a minor thing, though)

extends UnaryExpression {
override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(child = newChild)
}

private def getCast: Option[Cast] = child match {
case c: Cast => Some(c)
case ExpressionProxy(c: Cast, _, _) => Some(c)
case _ => None
}

override def eval(input: InternalRow): Any = try {
child.eval(input)
} catch {
case e: ArithmeticException =>
getCast match {
case Some(cast) =>
throw DeltaErrors.castingCauseOverflowErrorInTableWrite(
cast.child.dataType,
cast.dataType,
columnName)
case None => throw e
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
getCast match {
case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
case None => child.genCode(ctx)
}
}

def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
val childGen = child.genCode(ctx)
val exceptionClass = classOf[ArithmeticException].getCanonicalName
assert(child.isInstanceOf[Cast])
val cast = child.asInstanceOf[Cast]
val fromDt =
ctx.addReferenceObj("from", cast.child.dataType, cast.child.dataType.getClass.getName)
val toDt = ctx.addReferenceObj("to", child.dataType, child.dataType.getClass.getName)
val col = ctx.addReferenceObj("colName", columnName, "java.lang.String")
// scalastyle:off line.size.limit
ev.copy(code =
code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
try {
${childGen.code}
${ev.isNull} = ${childGen.isNull};
${ev.value} = ${childGen.value};
} catch ($exceptionClass e) {
throw org.apache.spark.sql.delta.DeltaErrors
.castingCauseOverflowErrorInTableWrite($fromDt, $toDt, $col);
}"""
)
// scalastyle:on line.size.limit
}

override def dataType: DataType = child.dataType

override def sql: String = child.sql

override def toString: String = child.toString
}
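
To summarize the wrapping that cast() performs under StoreAssignmentPolicy.ANSI, a minimal sketch (literal value, time zone, and column name are illustrative):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.types.ByteType

    // INT -> BYTE is not an up-cast and targets an integral type, so
    // canCauseCastOverflow is true and the ANSI cast gets wrapped.
    val ansiCast = Cast(Literal(1024), ByteType, Some("UTC"), ansiEnabled = true)
    val checked = CheckOverflowInTableWrite(ansiCast, columnName = "b")
    // Evaluation rethrows Spark's generic overflow as the Delta error, naming `b`:
    checked.eval(InternalRow.empty) // throws DELTA_CAST_OVERFLOW_IN_TABLE_WRITE

A cast that cannot overflow (for example INT to LONG) is never wrapped, so the common case pays no extra cost.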
@@ -23,7 +23,6 @@ import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

/**
* [[SQLConf]] entries for Delta features.
@@ -1254,6 +1253,15 @@ trait DeltaSQLConfBase
.intConf
.createWithDefault(100 * 1000)

val UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG =
buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag")
.internal()
.doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows
|'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'. This
|was the default before Delta 3.5.""".stripMargin)
.booleanConf
.createWithDefault(false)

}

object DeltaSQLConf extends DeltaSQLConfBase
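
Finally, a usage sketch for the escape hatches mentioned in the error message (the full key, including the spark.databricks.delta. prefix, matches the one spelled out in the comment on cast() above):

    // Revert to the pre-change behaviour: UPDATE/MERGE casts follow
    // spark.sql.ansi.enabled rather than spark.sql.storeAssignmentPolicy.
    spark.conf.set(
      "spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag", "true")
    // Or keep the new behaviour but choose the permissive policy:
    spark.conf.set("spark.sql.storeAssignmentPolicy", "LEGACY")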