Skip to content

Commit 19d77ea

Browse files
committed
MERGE ... UPDATE/INSERT * should do by-name resolution
1 parent 33c1034 commit 19d77ea

File tree

2 files changed

+86
-36
lines changed

2 files changed

+86
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,14 +1469,18 @@ class Analyzer(override val catalogManager: CatalogManager)
14691469
case UpdateAction(updateCondition, assignments) =>
14701470
val resolvedUpdateCondition = updateCondition.map(
14711471
resolveExpressionByPlanChildren(_, m))
1472-
// The update value can access columns from both target and source tables.
14731472
UpdateAction(
14741473
resolvedUpdateCondition,
1475-
resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = false))
1474+
// The update value can access columns from both target and source tables.
1475+
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false))
14761476
case UpdateStarAction(updateCondition) =>
1477+
val assignments = targetTable.output.map { attr =>
1478+
Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
1479+
}
14771480
UpdateAction(
14781481
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
1479-
resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = false))
1482+
// For UPDATE *, the value must from source table.
1483+
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
14801484
case o => o
14811485
}
14821486
val newNotMatchedActions = m.notMatchedActions.map {
@@ -1487,15 +1491,18 @@ class Analyzer(override val catalogManager: CatalogManager)
14871491
resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable)))
14881492
InsertAction(
14891493
resolvedInsertCondition,
1490-
resolveAssignments(Some(assignments), m, resolveValuesWithSourceOnly = true))
1494+
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
14911495
case InsertStarAction(insertCondition) =>
14921496
// The insert action is used when not matched, so its condition and value can only
14931497
// access columns from the source table.
14941498
val resolvedInsertCondition = insertCondition.map(
14951499
resolveExpressionByPlanChildren(_, Project(Nil, m.sourceTable)))
1500+
val assignments = targetTable.output.map { attr =>
1501+
Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
1502+
}
14961503
InsertAction(
14971504
resolvedInsertCondition,
1498-
resolveAssignments(assignments = None, m, resolveValuesWithSourceOnly = true))
1505+
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
14991506
case o => o
15001507
}
15011508
val resolvedMergeCondition = resolveExpressionByPlanChildren(m.mergeCondition, m)
@@ -1513,33 +1520,38 @@ class Analyzer(override val catalogManager: CatalogManager)
15131520
}
15141521

15151522
def resolveAssignments(
1516-
assignments: Option[Seq[Assignment]],
1523+
assignments: Seq[Assignment],
15171524
mergeInto: MergeIntoTable,
15181525
resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = {
1519-
if (assignments.isEmpty) {
1520-
val expandedColumns = mergeInto.targetTable.output
1521-
val expandedValues = mergeInto.sourceTable.output
1522-
expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2))
1523-
} else {
1524-
assignments.get.map { assign =>
1525-
val resolvedKey = assign.key match {
1526-
case c if !c.resolved =>
1527-
resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.targetTable))
1528-
case o => o
1529-
}
1530-
val resolvedValue = assign.value match {
1531-
// The update values may contain target and/or source references.
1532-
case c if !c.resolved =>
1533-
if (resolveValuesWithSourceOnly) {
1534-
resolveExpressionByPlanChildren(c, Project(Nil, mergeInto.sourceTable))
1535-
} else {
1536-
resolveExpressionByPlanChildren(c, mergeInto)
1537-
}
1538-
case o => o
1539-
}
1540-
Assignment(resolvedKey, resolvedValue)
1526+
assignments.map { assign =>
1527+
val resolvedKey = assign.key match {
1528+
case c if !c.resolved =>
1529+
resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
1530+
case o => o
15411531
}
1532+
val resolvedValue = assign.value match {
1533+
// The update values may contain target and/or source references.
1534+
case c if !c.resolved =>
1535+
if (resolveValuesWithSourceOnly) {
1536+
resolveMergeExprOrFail(c, Project(Nil, mergeInto.sourceTable))
1537+
} else {
1538+
resolveMergeExprOrFail(c, mergeInto)
1539+
}
1540+
case o => o
1541+
}
1542+
Assignment(resolvedKey, resolvedValue)
1543+
}
1544+
}
1545+
1546+
private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
1547+
val resolved = resolveExpressionByPlanChildren(e, p)
1548+
resolved.references.filter(!_.resolved).foreach { a =>
1549+
// Note: This will throw error only on unresolved attribute issues,
1550+
// not other resolution errors like mismatched data types.
1551+
val cols = p.inputSet.toSeq.map(_.sql).mkString(", ")
1552+
a.failAnalysis(s"cannot resolve ${a.sql} in MERGE command given columns [$cols]")
15421553
}
1554+
resolved
15431555
}
15441556

15451557
// This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level

sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ class PlanResolutionSuite extends AnalysisTest {
5555
t
5656
}
5757

58+
private val table1: Table = {
59+
val t = mock(classOf[Table])
60+
when(t.schema()).thenReturn(new StructType().add("s", "string").add("i", "int"))
61+
when(t.partitioning()).thenReturn(Array.empty[Transform])
62+
t
63+
}
64+
65+
private val table2: Table = {
66+
val t = mock(classOf[Table])
67+
when(t.schema()).thenReturn(new StructType().add("i", "int").add("x", "string"))
68+
when(t.partitioning()).thenReturn(Array.empty[Transform])
69+
t
70+
}
71+
5872
private val tableWithAcceptAnySchemaCapability: Table = {
5973
val t = mock(classOf[Table])
6074
when(t.schema()).thenReturn(new StructType().add("i", "int"))
@@ -91,7 +105,8 @@ class PlanResolutionSuite extends AnalysisTest {
91105
when(newCatalog.loadTable(any())).thenAnswer((invocation: InvocationOnMock) => {
92106
invocation.getArgument[Identifier](0).name match {
93107
case "tab" => table
94-
case "tab1" => table
108+
case "tab1" => table1
109+
case "tab2" => table2
95110
case name => throw new NoSuchTableException(name)
96111
}
97112
})
@@ -107,7 +122,7 @@ class PlanResolutionSuite extends AnalysisTest {
107122
case "v1Table1" => v1Table
108123
case "v1HiveTable" => v1HiveTable
109124
case "v2Table" => table
110-
case "v2Table1" => table
125+
case "v2Table1" => table1
111126
case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability
112127
case "view" => view
113128
case name => throw new NoSuchTableException(name)
@@ -1385,7 +1400,7 @@ class PlanResolutionSuite extends AnalysisTest {
13851400
// cte
13861401
val sql5 =
13871402
s"""
1388-
|WITH source(i, s) AS
1403+
|WITH source(s, i) AS
13891404
| (SELECT * FROM $source)
13901405
|MERGE INTO $target AS target
13911406
|USING source
@@ -1405,7 +1420,7 @@ class PlanResolutionSuite extends AnalysisTest {
14051420
updateAssigns)),
14061421
Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))),
14071422
insertAssigns))) =>
1408-
assert(source.output.map(_.name) == Seq("i", "s"))
1423+
assert(source.output.map(_.name) == Seq("s", "i"))
14091424
checkResolution(target, source, mergeCondition, Some(dl), Some(ul), Some(il),
14101425
updateAssigns, insertAssigns)
14111426

@@ -1414,8 +1429,7 @@ class PlanResolutionSuite extends AnalysisTest {
14141429
}
14151430

14161431
// no aliases
1417-
Seq(("v2Table", "v2Table1"),
1418-
("testcat.tab", "testcat.tab1")).foreach { pair =>
1432+
Seq(("v2Table", "v2Table1"), ("testcat.tab", "testcat.tab1")).foreach { pair =>
14191433

14201434
val target = pair._1
14211435
val source = pair._2
@@ -1507,7 +1521,7 @@ class PlanResolutionSuite extends AnalysisTest {
15071521
assert(e5.message.contains("Reference 's' is ambiguous"))
15081522
}
15091523

1510-
val sql6 =
1524+
val sql1 =
15111525
s"""
15121526
|MERGE INTO non_exist_target
15131527
|USING non_exist_source
@@ -1516,13 +1530,37 @@ class PlanResolutionSuite extends AnalysisTest {
15161530
|WHEN MATCHED THEN UPDATE SET *
15171531
|WHEN NOT MATCHED THEN INSERT *
15181532
""".stripMargin
1519-
val parsed = parseAndResolve(sql6)
1533+
val parsed = parseAndResolve(sql1)
15201534
parsed match {
15211535
case u: MergeIntoTable =>
15221536
assert(u.targetTable.isInstanceOf[UnresolvedRelation])
15231537
assert(u.sourceTable.isInstanceOf[UnresolvedRelation])
15241538
case _ => fail("Expect MergeIntoTable, but got:\n" + parsed.treeString)
15251539
}
1540+
1541+
// UPDATE * with incompatible schema between source and target tables.
1542+
val sql2 =
1543+
"""
1544+
|MERGE INTO testcat.tab
1545+
|USING testcat.tab2
1546+
|ON 1 = 1
1547+
|WHEN MATCHED THEN UPDATE SET *
1548+
|""".stripMargin
1549+
val e2 = intercept[AnalysisException](parseAndResolve(sql2))
1550+
assert(e2.message.contains(
1551+
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
1552+
1553+
// INSERT * with incompatible schema between source and target tables.
1554+
val sql3 =
1555+
"""
1556+
|MERGE INTO testcat.tab
1557+
|USING testcat.tab2
1558+
|ON 1 = 1
1559+
|WHEN NOT MATCHED THEN INSERT *
1560+
|""".stripMargin
1561+
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
1562+
assert(e3.message.contains(
1563+
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
15261564
}
15271565

15281566
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {

0 commit comments

Comments
 (0)