
[SPARK-51984][SQL] Support Check Constraint enforcement in UpdateTable and MergeIntoTable #50943

Open · wants to merge 9 commits into master
@@ -449,14 +449,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveEncodersInUDF),
// The rewrite rules might move resolved query plan into subquery. Once the resolved plan
// contains ScalaUDF, their encoders won't be resolved if `ResolveEncodersInUDF` is not
// applied before the rewrite rules. So we need to apply `ResolveEncodersInUDF` before the
// rewrite rules.
// applied before the rewrite rules. So we need to apply the rewrite rules after
// `ResolveEncodersInUDF`
Batch("DML rewrite", fixedPoint,
Member Author: Context of why this batch is not in the main resolution batch: fab6d83

RewriteDeleteFromTable,
RewriteUpdateTable,
RewriteMergeIntoTable,
// Ensures columns of an output table are correctly resolved from the data in a logical plan.
ResolveOutputRelation),
ResolveOutputRelation,
// Apply table check constraints to validate data during write operations.
new ResolveTableConstraints(catalogManager)),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
@@ -1437,6 +1439,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
new ResolveReferencesInUpdate(catalogManager)
private val resolveReferencesInSort =
new ResolveReferencesInSort(catalogManager)
private val resolveReferencesInFilter =
new ResolveReferencesInFilter(catalogManager)

/**
* Return true if there're conflicting attributes among children's outputs of a plan
@@ -1483,9 +1487,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
checkTrailingCommaInSelect(expanded, starRemoved = true)
}
expanded
// If the filter list contains Stars, expand it.
case p: Filter if containsStar(Seq(p.condition)) =>
p.copy(expandStarExpression(p.condition, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
@@ -1711,23 +1712,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Project(child.output, r.copy(resolvedFinal, newChild))
}

// Filter can host both grouping expressions/aggregate functions and missing attributes.
// The grouping expressions/aggregate functions resolution takes precedence over missing
// attributes. See the classdoc of `ResolveReferences` for details.
case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty =>
val resolvedBasic = resolveExpressionByPlanChildren(cond, f)
val resolvedWithAgg = resolveColWithAgg(resolvedBasic, child)
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child)
// Missing columns should be resolved right after basic column resolution.
// See the doc of `ResolveReferences`.
val resolvedFinal = resolveColsLastResort(newCond.head)
if (child.output == newChild.output) {
f.copy(condition = resolvedFinal)
case f: Filter =>
// If the filter list contains Stars, expand it.
val afterStarExpansion = if (containsStar(Seq(f.condition))) {
f.copy(expandStarExpression(f.condition, f.child))
} else {
// Add missing attributes and then project them away.
val newFilter = Filter(resolvedFinal, newChild)
Project(child.output, newFilter)
f
}
resolveReferencesInFilter.apply(afterStarExpansion)

case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
resolveReferencesInSort(s)
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.connector.catalog.CatalogManager


/**
* A virtual rule to resolve [[UnresolvedAttribute]] in [[Filter]]. It's only used by the real
* rules `ResolveReferences` and `ResolveTableConstraints`. Filters containing unresolved stars
* should have been expanded before applying this rule.
* Filter can host both grouping expressions/aggregate functions and missing attributes.
* The grouping expressions/aggregate functions resolution takes precedence over missing
* attributes. See the classdoc of `ResolveReferences` for details.
*/
class ResolveReferencesInFilter(val catalogManager: CatalogManager)
extends SQLConfHelper with ColumnResolutionHelper {
def apply(f: Filter): LogicalPlan = {
if (f.condition.resolved && f.missingInput.isEmpty) {
return f
}
val resolvedBasic = resolveExpressionByPlanChildren(f.condition, f)
val resolvedWithAgg = resolveColWithAgg(resolvedBasic, f.child)
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), f.child)
// Missing columns should be resolved right after basic column resolution.
// See the doc of `ResolveReferences`.
val resolvedFinal = resolveColsLastResort(newCond.head)
if (f.child.output == newChild.output) {
f.copy(condition = resolvedFinal)
} else {
// Add missing attributes and then project them away.
val newFilter = Filter(resolvedFinal, newChild)
Project(f.child.output, newFilter)
}
}

}
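For context, both situations the classdoc describes can be reached from ordinary user code; a hedged illustration with a plain SparkSession (the view and column names are made up):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[1]").appName("filter-resolution").getOrCreate()
spark.range(10).selectExpr("id % 3 AS a", "id AS b").createOrReplaceTempView("t")

// 1. HAVING becomes a Filter hosting an aggregate function over the Aggregate below it.
spark.sql("SELECT a, count(*) AS cnt FROM t GROUP BY a HAVING max(b) > 5").show()

// 2. Filtering on a column that was projected away yields a Filter with a missing input
//    attribute; the resolver re-adds `b` below the Filter and then projects it away again.
val df = spark.range(10).selectExpr("id AS a", "id * 2 AS b")
df.select("a").filter(col("b") > 4).show()
```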
@@ -22,38 +22,58 @@ import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, Expressio
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, V2WriteCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.{CatalogManager, Table}
import org.apache.spark.sql.connector.catalog.constraints.Check
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[LogicalPlan] {

private val resolveReferencesInFilter = new ResolveReferencesInFilter(catalogManager)

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
_.containsPattern(COMMAND), ruleId) {
case v2Write: V2WriteCommand
if v2Write.table.resolved && v2Write.query.resolved &&
!containsCheckInvariant(v2Write.query) && v2Write.outputResolved =>
v2Write.table match {
case r: DataSourceV2Relation
if r.table.constraints != null && r.table.constraints.nonEmpty =>
// Check constraint is the only enforced constraint for DSV2 tables.
val checkInvariants = r.table.constraints.collect {
case c: Check =>
val unresolvedExpr = buildCatalystExpression(c)
val columnExtractors = mutable.Map[String, Expression]()
buildColumnExtractors(unresolvedExpr, columnExtractors)
CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql)
}
// Combine the check invariants into a single expression using conjunctive AND.
checkInvariants.reduceOption(And).fold(v2Write)(
condition => v2Write.withNewQuery(Filter(condition, v2Write.query)))
case r: DataSourceV2Relation =>
Member Author: The changes in this file are purely code cleanup. cc @aokolnychyi

buildCheckCondition(r.table).map { condition =>
val filter = Filter(condition, v2Write.query)
// Resolve attribute references in the filter condition only, not the entire query.
Member Author: See https://github.com/gengliangwang/spark/actions/runs/15130563620/job/42530638818 for details.
=== Result of Batch DML rewrite ===
!'Sort [(cast(x#26 as bigint) + cast('c.y as bigint)) ASC NULLS FIRST], true   Sort [(cast(x#26 as bigint) + cast(tempresolvedcolumn(c#24.y, c, y, false) as bigint)) ASC NULLS FIRST], true
 +- Aggregate [c#24.x], [c#24.x AS x#26]                                       +- Aggregate [c#24.x], [c#24.x AS x#26]
    +- SubqueryAlias t                                                            +- SubqueryAlias t
       +- LocalRelation [c#24]                                                       +- LocalRelation [c#24]
          ,{})

// We use a targeted resolver (ResolveReferencesInFilter) instead of the full
// `ResolveReferences` rule to avoid the creation of `TempResolvedColumn` nodes that
// would interfere with the analyzer's ability to correctly identify unresolved
// attributes.
val resolvedFilter = resolveReferencesInFilter(filter)
v2Write.withNewQuery(resolvedFilter)
}.getOrElse(v2Write)
case _ =>
v2Write
}
}

// Constructs an optional check condition based on the table's check constraints.
// This condition validates data during write operations.
// Returns None if no check constraints exist; otherwise, combines all constraints using
// logical AND.
private def buildCheckCondition(table: Table): Option[Expression] = {
if (table.constraints == null || table.constraints.isEmpty) {
None
} else {
val checkInvariants = table.constraints.collect {
// Check constraint is the only enforced constraint for DSV2 tables.
case c: Check =>
val unresolvedExpr = buildCatalystExpression(c)
val columnExtractors = mutable.Map[String, Expression]()
buildColumnExtractors(unresolvedExpr, columnExtractors)
CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql)
}
checkInvariants.reduceOption(And)
}
}
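The combination step is simply `reduceOption(And)` over one `CheckInvariant` per `Check` constraint. A standalone sketch of that folding with hypothetical predicates (plain catalyst expressions stand in for CheckInvariant here):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, Expression, GreaterThan, LessThan, Literal}

// Hypothetical constraint predicates; in the rule they are built from Check.predicateSql.
val predicates: Seq[Expression] = Seq(
  GreaterThan(UnresolvedAttribute("pk"), Literal(0)),
  LessThan(UnresolvedAttribute("id"), Literal(100)))

// None when the table declares no check constraints, otherwise a single conjunction that
// becomes the condition of the Filter placed on top of the write query.
val combined: Option[Expression] = predicates.reduceOption(And)

// The resulting plan for a rewritten UPDATE/MERGE then looks roughly like (illustration
// only, not actual analyzer output):
//   ReplaceData / WriteDelta
//     +- Filter CheckInvariant(...) AND CheckInvariant(...)
//        +- <rewritten write query>
```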

private def containsCheckInvariant(plan: LogicalPlan): Boolean = {
plan match {
plan exists {
case Filter(condition, _) =>
condition.exists(_.isInstanceOf[CheckInvariant])

@@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.write
import java.util

import org.apache.spark.sql.connector.catalog.{Column, SupportsRead, SupportsRowLevelOperations, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -40,6 +41,7 @@ private[sql] case class RowLevelOperationTable(
override def schema: StructType = table.schema
override def columns: Array[Column] = table.columns()
override def capabilities: util.Set[TableCapability] = table.capabilities
override def constraints(): Array[Constraint] = table.constraints()
override def toString: String = table.toString

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
@@ -17,37 +17,27 @@

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.expressions.Transform

class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog {
import CatalogV2Implicits._

override def createTable(
ident: Identifier,
columns: Array[Column],
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
}

InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
InMemoryTableCatalog.maybeSimulateFailedTableCreation(tableInfo.properties)

val tableName = s"$name.${ident.quoted}"
val schema = CatalogV2Util.v2ColumnsToStructType(columns)
val table = new InMemoryRowLevelOperationTable(tableName, schema, partitions, properties)
val schema = CatalogV2Util.v2ColumnsToStructType(tableInfo.columns)
val table = new InMemoryRowLevelOperationTable(
tableName, schema, tableInfo.partitions, tableInfo.properties, tableInfo.constraints())
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}

override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
}

override def alterTable(ident: Identifier, changes: TableChange*): Table = {
val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable]
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
@@ -63,6 +63,31 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
Row(6, "hr", "new-text")))
}

test("delete from table with table constraints") {
sql(
s"""
|CREATE TABLE $tableNameAsString (
| pk INT NOT NULL PRIMARY KEY,
| id INT UNIQUE,
| dep STRING,
| CONSTRAINT pk_check CHECK (pk > 0))
| PARTITIONED BY (dep)
|""".stripMargin)
append("pk INT NOT NULL, id INT, dep STRING",
"""{ "pk": 1, "id": 2, "dep": "hr" }
|{ "pk": 2, "id": 4, "dep": "eng" }
|{ "pk": 3, "id": 6, "dep": "eng" }
|""".stripMargin)
sql(s"DELETE FROM $tableNameAsString WHERE pk < 2")
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(Row(2, 4, "eng"), Row(3, 6, "eng")))
sql(s"DELETE FROM $tableNameAsString WHERE pk >=3")
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(Row(2, 4, "eng")))
}
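The new test above exercises DELETE; for the UPDATE path this PR enables, a companion test in the corresponding UPDATE suite might look like the sketch below (hypothetical, not part of this diff; that the violation surfaces as an exception whose message contains the constraint name is an assumption):

```scala
test("update rejects rows violating pk_check (sketch)") {
  sql(
    s"""
       |CREATE TABLE $tableNameAsString (
       | pk INT NOT NULL PRIMARY KEY,
       | id INT UNIQUE,
       | dep STRING,
       | CONSTRAINT pk_check CHECK (pk > 0))
       | PARTITIONED BY (dep)
       |""".stripMargin)
  append("pk INT NOT NULL, id INT, dep STRING",
    """{ "pk": 1, "id": 2, "dep": "hr" }""")
  // Setting pk to a non-positive value should be rejected by the injected CheckInvariant.
  val e = intercept[Exception] {
    sql(s"UPDATE $tableNameAsString SET pk = -1 WHERE id = 2")
  }
  // Assumption: the violation message mentions the constraint name.
  assert(e.getMessage.contains("pk_check"))
}
```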

test("delete from table containing struct column with default value") {
sql(
s"""CREATE TABLE $tableNameAsString (
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Update, Write}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write}
import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan}
Expand Down Expand Up @@ -102,7 +102,12 @@ abstract class RowLevelOperationSuiteBase

protected def createTable(columns: Array[Column]): Unit = {
val transforms = Array[Transform](identity(reference(Seq("dep"))))
catalog.createTable(ident, columns, transforms, extraTableProps)
val tableInfo = new TableInfo.Builder()
.withColumns(columns)
.withPartitions(transforms)
.withProperties(extraTableProps)
.build()
catalog.createTable(ident, tableInfo)
}

protected def createAndInitTable(schemaString: String, jsonData: String): Unit = {
@@ -62,7 +62,7 @@ trait DDLCommandTestUtils extends SQLTestUtils {
(f: String => Unit): Unit = {
val nsCat = s"$cat.$ns"
withNamespace(nsCat) {
sql(s"CREATE NAMESPACE $nsCat")
sql(s"CREATE NAMESPACE IF NOT EXISTS $nsCat")
val t = s"$nsCat.$tableName"
withTable(t) {
f(t)