[Spark] Widen all UDFs during conflict checking
#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

Delta conflict detection widens non-deterministic expressions before applying them to the changes of the winning transaction. Unfortunately, user-defined functions are marked as deterministic by default, and customers must explicitly mark them as non-deterministic. This can result in actually non-deterministic UDFs incorrectly being treated as deterministic. This commit makes conflict detection widen all UDFs, so that customers cannot shoot themselves in the foot.
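
For background, a minimal sketch (my illustration, not part of this change) of the Spark default the fix guards against: a UDF is treated as deterministic unless its author explicitly opts out via `asNondeterministic()`.

```scala
import org.apache.spark.sql.functions.udf

// Not part of this PR: demonstrates Spark's default determinism flag on UDFs.
object UdfDeterminismSketch extends App {
  // Clearly non-deterministic body, but Spark assumes determinism by default.
  val randomBucket = udf(() => scala.util.Random.nextInt(10))
  println(randomBucket().expr.deterministic) // true: trusted unless the user opts out

  // Only after an explicit opt-out does Spark treat it as non-deterministic.
  val honest = randomBucket.asNondeterministic()
  println(honest().expr.deterministic)       // false
}
```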

## How was this patch tested?

Existing tests.

## Does this PR introduce _any_ user-facing changes?

No

Closes #2553

GitOrigin-RevId: 851a127034b98fed886621491705e344b6f87c11
cstavr authored and allisonport-db committed Jan 31, 2024
1 parent d013462 commit 4aab4d3
Showing 5 changed files with 99 additions and 31 deletions.
@@ -25,6 +25,7 @@ import org.apache.spark.sql.delta.RowId.RowTrackingMetadataDomain
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.DeltaSparkPlanUtils.CheckDeterministicOptions
import org.apache.spark.sql.delta.util.FileNames
import org.apache.hadoop.fs.FileStatus

@@ -244,15 +245,15 @@ private[delta] class ConflictChecker(
spark.conf.get(DeltaSQLConf.DELTA_CONFLICT_DETECTION_WIDEN_NONDETERMINISTIC_PREDICATES) match {
case DeltaSQLConf.NonDeterministicPredicateWidening.OFF =>
getFirstFileMatchingPartitionPredicatesInternal(
filesDf, shouldWidenNonDeterministicPredicates = false)
filesDf, shouldWidenNonDeterministicPredicates = false, shouldWidenAllUdf = false)
case wideningMode =>
val fileWithWidening = getFirstFileMatchingPartitionPredicatesInternal(
filesDf, shouldWidenNonDeterministicPredicates = true)
filesDf, shouldWidenNonDeterministicPredicates = true, shouldWidenAllUdf = true)

fileWithWidening.flatMap { fileWithWidening =>
val fileWithoutWidening =
getFirstFileMatchingPartitionPredicatesInternal(
filesDf, shouldWidenNonDeterministicPredicates = false)
filesDf, shouldWidenNonDeterministicPredicates = false, shouldWidenAllUdf = false)
if (fileWithoutWidening.isEmpty) {
// Conflict due to widening of non-deterministic predicate.
recordDeltaEvent(deltaLog,
@@ -261,7 +262,10 @@
data = Map(
"wideningMode" -> wideningMode,
"predicate" ->
currentTransactionInfo.readPredicates.map(_.partitionPredicate.toString)))
currentTransactionInfo.readPredicates.map(_.partitionPredicate.toString),
"deterministicUDFs" -> containsDeterministicUDF(
currentTransactionInfo.readPredicates, partitionedOnly = true))
)
}
if (wideningMode == DeltaSQLConf.NonDeterministicPredicateWidening.ON) {
Some(fileWithWidening)
@@ -274,13 +278,16 @@

private def getFirstFileMatchingPartitionPredicatesInternal(
filesDf: DataFrame,
shouldWidenNonDeterministicPredicates: Boolean): Option[AddFile] = {
shouldWidenNonDeterministicPredicates: Boolean,
shouldWidenAllUdf: Boolean): Option[AddFile] = {

def rewritePredicateFn(
predicate: Expression,
shouldRewriteFilter: Boolean): DeltaTableReadPredicate = {
val rewrittenPredicate = if (shouldWidenNonDeterministicPredicates) {
eliminateNonDeterministicPredicates(Seq(predicate)).newPredicates
val checkDeterministicOptions =
CheckDeterministicOptions(allowDeterministicUdf = !shouldWidenAllUdf)
eliminateNonDeterministicPredicates(Seq(predicate), checkDeterministicOptions).newPredicates
} else {
Seq(predicate)
}
@@ -17,6 +17,7 @@
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.util.DeltaSparkPlanUtils
import org.apache.spark.sql.delta.util.DeltaSparkPlanUtils.CheckDeterministicOptions

import org.apache.spark.sql.catalyst.expressions.{And, EmptyRow, Expression, Literal, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
@@ -86,22 +87,25 @@ private[delta] trait ConflictCheckerPredicateElimination extends DeltaSparkPlanUtils
* files when these predicates are used for file skipping.
*/
protected def eliminateNonDeterministicPredicates(
predicates: Seq[Expression]): PredicateElimination = {
predicates: Seq[Expression],
checkDeterministicOptions: CheckDeterministicOptions): PredicateElimination = {
eliminateUnsupportedPredicates(predicates) {
case p @ SubqueryExpression(plan) =>
findFirstNonDeltaScan(plan) match {
case Some(plan) => PredicateElimination.eliminate(p, eliminated = Some(plan.nodeName))
case None =>
findFirstNonDeterministicNode(plan) match {
findFirstNonDeterministicNode(plan, checkDeterministicOptions) match {
case Some(node) =>
PredicateElimination.eliminate(p, eliminated = Some(planOrExpressionName(node)))
case None => PredicateElimination.keep(p)
}
}
// And and Or can safely be recursed through. Replacing any non-deterministic sub-tree
// with `True` will lead us to at most select more files than necessary later.
case p: And => PredicateElimination.recurse(p, eliminateNonDeterministicPredicates)
case p: Or => PredicateElimination.recurse(p, eliminateNonDeterministicPredicates)
case p: And => PredicateElimination.recurse(p,
p => eliminateNonDeterministicPredicates(p, checkDeterministicOptions))
case p: Or => PredicateElimination.recurse(p,
p => eliminateNonDeterministicPredicates(p, checkDeterministicOptions))
// All other expressions must either be completely deterministic,
// or must be replaced entirely, since replacing only their non-deterministic children
// may lead to files wrongly being deselected (e.g. `NOT True`).
@@ -110,7 +114,7 @@ private[delta] trait ConflictCheckerPredicateElimination extends DeltaSparkPlanUtils
// deterministic. This gives us better feedback on what caused the non-determinism in
// cases where `p` itself is deterministic but `p.deterministic = false` due to correctly
// detected non-deterministic child nodes.
findFirstNonDeterministicChildNode(p.children) match {
findFirstNonDeterministicChildNode(p.children, checkDeterministicOptions) match {
case Some(node) =>
PredicateElimination.eliminate(p, eliminated = Some(planOrExpressionName(node)))
case None => if (p.deterministic) {
@@ -239,6 +239,8 @@ trait MergeIntoMaterializeSource extends DeltaLogging with DeltaSparkPlanUtils {
val forceMaterializationWithUnreadableFiles =
spark.conf.get(DeltaSQLConf.MERGE_FORCE_SOURCE_MATERIALIZATION_WITH_UNREADABLE_FILES)
import DeltaSQLConf.MergeMaterializeSource._
val checkDeterministicOptions =
DeltaSparkPlanUtils.CheckDeterministicOptions(allowDeterministicUdf = true)
materializeType match {
case ALL =>
(true, MergeIntoMaterializeSourceReason.MATERIALIZE_ALL)
@@ -249,7 +251,7 @@ trait MergeIntoMaterializeSource extends DeltaLogging with DeltaSparkPlanUtils {
(false, MergeIntoMaterializeSourceReason.NOT_MATERIALIZED_AUTO_INSERT_ONLY)
} else if (!planContainsOnlyDeltaScans(source)) {
(true, MergeIntoMaterializeSourceReason.NON_DETERMINISTIC_SOURCE_NON_DELTA)
} else if (!planIsDeterministic(source)) {
} else if (!planIsDeterministic(source, checkDeterministicOptions)) {
(true, MergeIntoMaterializeSourceReason.NON_DETERMINISTIC_SOURCE_OPERATORS)
// Force source materialization if Spark configs IGNORE_CORRUPT_FILES,
// IGNORE_MISSING_FILES or file source read options FileSourceOptions.IGNORE_CORRUPT_FILES
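
Worth noting as a design choice: on this MERGE path the options are built with `allowDeterministicUdf = true`, so source-materialization decisions keep trusting user-declared UDF determinism; only conflict checking (above) becomes stricter. A construction-only sketch, assuming delta-spark on the classpath:

```scala
import org.apache.spark.sql.delta.util.DeltaSparkPlanUtils.CheckDeterministicOptions

object DeterminismOptionsSketch {
  // MERGE source materialization: keep trusting UDFs marked deterministic.
  val mergeOptions = CheckDeterministicOptions(allowDeterministicUdf = true)
  // Conflict checking: always treat UDFs as non-deterministic and widen them.
  val conflictOptions = CheckDeterministicOptions(allowDeterministicUdf = false)
}
```
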
@@ -16,14 +16,16 @@

package org.apache.spark.sql.delta.util

import org.apache.spark.sql.delta.DeltaTable
import org.apache.spark.sql.delta.{DeltaTable, DeltaTableReadPredicate}

import org.apache.spark.sql.catalyst.expressions.{Exists, Expression, InSubquery, LateralSubquery, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{Exists, Expression, InSubquery, LateralSubquery, ScalarSubquery, UserDefinedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, LeafNode, LogicalPlan, OneRowRelation, Project, SubqueryAlias, Union}
import org.apache.spark.sql.execution.datasources.LogicalRelation


trait DeltaSparkPlanUtils {
import DeltaSparkPlanUtils._

protected def planContainsOnlyDeltaScans(source: LogicalPlan): Boolean =
findFirstNonDeltaScan(source).isEmpty

@@ -43,53 +45,65 @@ trait DeltaSparkPlanUtils {
/**
* Returns `true` if `plan` has a safe level of determinism. This is a conservative
* approximation of `plan` being a truly deterministic query.
*
*/
protected def planIsDeterministic(plan: LogicalPlan): Boolean =
findFirstNonDeterministicNode(plan).isEmpty
protected def planIsDeterministic(
plan: LogicalPlan,
checkDeterministicOptions: CheckDeterministicOptions): Boolean =
findFirstNonDeterministicNode(plan, checkDeterministicOptions).isEmpty

type PlanOrExpression = Either[LogicalPlan, Expression]

/**
* Returns a part of the `plan` that does not have a safe level of determinism.
* This is a conservative approximation of `plan` being a truly deterministic query.
*/
protected def findFirstNonDeterministicNode(plan: LogicalPlan): Option[PlanOrExpression] = {
protected def findFirstNonDeterministicNode(
plan: LogicalPlan,
checkDeterministicOptions: CheckDeterministicOptions): Option[PlanOrExpression] = {
plan match {
// This is very restrictive, allowing only deterministic filters and projections directly
// on top of a Delta Table.
case Distinct(child) => findFirstNonDeterministicNode(child)
case Distinct(child) => findFirstNonDeterministicNode(child, checkDeterministicOptions)
case Project(projectList, child) =>
findFirstNonDeterministicChildNode(projectList) orElse {
findFirstNonDeterministicNode(child)
findFirstNonDeterministicChildNode(projectList, checkDeterministicOptions) orElse {
findFirstNonDeterministicNode(child, checkDeterministicOptions)
}
case Filter(cond, child) =>
findFirstNonDeterministicNode(cond) orElse {
findFirstNonDeterministicNode(child)
findFirstNonDeterministicNode(cond, checkDeterministicOptions) orElse {
findFirstNonDeterministicNode(child, checkDeterministicOptions)
}
case Union(children, _, _) => collectFirst[LogicalPlan, PlanOrExpression](
children,
findFirstNonDeterministicNode)
case SubqueryAlias(_, child) => findFirstNonDeterministicNode(child)
c => findFirstNonDeterministicNode(c, checkDeterministicOptions))
case SubqueryAlias(_, child) =>
findFirstNonDeterministicNode(child, checkDeterministicOptions)
case DeltaTable(_) => None
case OneRowRelation() => None
case node => Some(Left(node))
}
}

protected def findFirstNonDeterministicChildNode(
children: Seq[Expression]): Option[PlanOrExpression] =
children: Seq[Expression],
checkDeterministicOptions: CheckDeterministicOptions): Option[PlanOrExpression] =
collectFirst[Expression, PlanOrExpression](
children,
findFirstNonDeterministicNode)
c => findFirstNonDeterministicNode(c, checkDeterministicOptions))

protected def findFirstNonDeterministicNode(child: Expression): Option[PlanOrExpression] = {
protected def findFirstNonDeterministicNode(
child: Expression,
checkDeterministicOptions: CheckDeterministicOptions): Option[PlanOrExpression] = {
child match {
case SubqueryExpression(plan) =>
findFirstNonDeltaScan(plan).map(Left(_)).orElse(findFirstNonDeterministicNode(plan))
findFirstNonDeltaScan(plan).map(Left(_))
.orElse(findFirstNonDeterministicNode(plan, checkDeterministicOptions))
case _: UserDefinedExpression if !checkDeterministicOptions.allowDeterministicUdf =>
Some(Right(child))
case p =>
collectFirst[Expression, PlanOrExpression](
p.children,
findFirstNonDeterministicNode) orElse {
c => findFirstNonDeterministicNode(c, checkDeterministicOptions)) orElse {
if (p.deterministic) None else Some(Right(p))
}
}
@@ -113,4 +127,39 @@ trait DeltaSparkPlanUtils {
case _ => None
}
}

/** Returns whether the read predicates of a transaction contain any deterministic UDFs. */
def containsDeterministicUDF(
predicates: Seq[DeltaTableReadPredicate], partitionedOnly: Boolean): Boolean = {
if (partitionedOnly) {
predicates.exists {
_.partitionPredicates.exists(containsDeterministicUDF)
}
} else {
predicates.exists { p =>
p.dataPredicates.exists(containsDeterministicUDF) ||
p.partitionPredicates.exists(containsDeterministicUDF)
}
}
}

/** Returns whether an expression contains any deterministic UDFs. */
def containsDeterministicUDF(expr: Expression): Boolean = expr.exists {
case udf: UserDefinedExpression => udf.deterministic
case _ => false
}
}


object DeltaSparkPlanUtils {
/**
* Options for deciding whether plans contain non-deterministic nodes and expressions.
*
* @param allowDeterministicUdf If true, allow UDFs that are marked by users as deterministic.
* If false, always treat them as non-deterministic to be more
* defensive against user bugs.
*/
case class CheckDeterministicOptions(
allowDeterministicUdf: Boolean
)
}
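
A quick, hedged usage sketch of the expression-level check above (the helper body is copied locally so the snippet stands alone; the UDF is made up): a UDF left at Spark's default counts as a "deterministic UDF", while one the user opted out of does not, which is exactly what the new telemetry field reports when widening changes the conflict outcome.

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, UserDefinedExpression}
import org.apache.spark.sql.functions.{col, udf}

object ContainsDeterministicUdfDemo extends App {
  // Local copy of the expression-level helper above, for a standalone run.
  def containsDeterministicUDF(expr: Expression): Boolean = expr.exists {
    case u: UserDefinedExpression => u.deterministic
    case _ => false
  }

  val plusOne = udf((i: Int) => i + 1) // deterministic by default
  println(containsDeterministicUDF(plusOne(col("x")).expr))                      // true
  println(containsDeterministicUDF(plusOne.asNondeterministic()(col("x")).expr)) // false
}
```
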
@@ -17,6 +17,7 @@
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN
import org.apache.spark.sql.delta.util.DeltaSparkPlanUtils

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
@@ -50,11 +51,16 @@ class ConflictCheckerPredicateEliminationUnitSuite
col("c").expr > ScalarSubquery(df.queryExecution.analyzed)
}

private def defaultEliminationFunction(e: Seq[Expression]): PredicateElimination = {
val options = DeltaSparkPlanUtils.CheckDeterministicOptions(allowDeterministicUdf = false)
eliminateNonDeterministicPredicates(e, options)
}

private def checkEliminationResult(
predicate: Expression,
expected: PredicateElimination,
eliminationFunction: Seq[Expression] => PredicateElimination =
eliminateNonDeterministicPredicates): Unit = {
eliminationFunction: Seq[Expression] => PredicateElimination = defaultEliminationFunction)
: Unit = {
require(expected.newPredicates.size === 1)
val actual = eliminationFunction(Seq(predicate))
assert(actual.newPredicates.size === 1)
