
Commit 816f6dd

karenfeng authored and cloud-fan committed
[SPARK-34527][SQL] Resolve duplicated common columns from USING/NATURAL JOIN
### What changes were proposed in this pull request?

Adds the duplicated common columns as hidden columns to the Projection used to rewrite NATURAL/USING JOINs.

### Why are the changes needed?

Allows users to resolve either side of the NATURAL/USING JOIN's common keys. Previously, the user could only resolve the following columns:

| Join type | Left key columns | Right key columns |
| --- | --- | --- |
| Inner | Yes | No |
| Left | Yes | No |
| Right | No | Yes |
| Outer | No | No |

### Does this PR introduce _any_ user-facing change?

Yes. The user can now symmetrically resolve the common columns from a NATURAL/USING JOIN.

### How was this patch tested?

SQL-side tests. The behavior matches PostgreSQL and MySQL.

Closes #31666 from karenfeng/spark-34527.

Authored-by: Karen Feng <karen.feng@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
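To make the user-facing change concrete, here is a minimal spark-shell sketch (the view values are illustrative; the shape mirrors the nt1/nt2 test views added below; before this patch, the `nt2.k` reference on the last line failed to resolve for an inner join):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch, assuming a local session; not part of the commit itself.
val spark = SparkSession.builder().master("local[*]").appName("spark-34527").getOrCreate()

spark.sql("""create temporary view nt1 as select * from values
  ("one", 1), ("two", 2), ("one", 5) as nt1(k, v1)""")
spark.sql("""create temporary view nt2 as select * from values
  ("one", 1), ("two", 22), ("one", 5) as nt2(k, v2)""")

// Both sides of the common key are now resolvable; previously only nt1.k was.
spark.sql("SELECT nt1.k, nt2.k FROM nt1 NATURAL JOIN nt2").show()
```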
1 parent 0fc97b5 commit 816f6dd

File tree

11 files changed: +843, −54 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 65 additions & 35 deletions

```diff
@@ -917,41 +917,30 @@ class Analyzer(override val catalogManager: CatalogManager)
    * Adds metadata columns to output for child relations when nodes are missing resolved attributes.
    *
    * References to metadata columns are resolved using columns from [[LogicalPlan.metadataOutput]],
-   * but the relation's output does not include the metadata columns until the relation is replaced
-   * using [[DataSourceV2Relation.withMetadataColumns()]]. Unless this rule adds metadata to the
-   * relation's output, the analyzer will detect that nothing produces the columns.
+   * but the relation's output does not include the metadata columns until the relation is replaced.
+   * Unless this rule adds metadata to the relation's output, the analyzer will detect that nothing
+   * produces the columns.
    *
    * This rule only adds metadata columns when a node is resolved but is missing input from its
    * children. This ensures that metadata columns are not added to the plan unless they are used. By
    * checking only resolved nodes, this ensures that * expansion is already done so that metadata
-   * columns are not accidentally selected by *.
+   * columns are not accidentally selected by *. This rule resolves operators downwards to avoid
+   * projecting away metadata columns prematurely.
    */
   object AddMetadataColumns extends Rule[LogicalPlan] {
-    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
 
-    private def hasMetadataCol(plan: LogicalPlan): Boolean = {
-      plan.expressions.exists(_.find {
-        case a: Attribute => a.isMetadataCol
-        case _ => false
-      }.isDefined)
-    }
-
-    private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
-      case r: DataSourceV2Relation => r.withMetadataColumns()
-      case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
-    }
+    import org.apache.spark.sql.catalyst.util._
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
+      // Add metadata output to all node types
       case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) =>
         val inputAttrs = AttributeSet(node.children.flatMap(_.output))
-        val metaCols = node.expressions.flatMap(_.collect {
-          case a: Attribute if a.isMetadataCol && !inputAttrs.contains(a) => a
-        })
+        val metaCols = getMetadataAttributes(node).filterNot(inputAttrs.contains)
         if (metaCols.isEmpty) {
           node
         } else {
           val newNode = addMetadataCol(node)
-          // We should not change the output schema of the plan. We should project away the extr
+          // We should not change the output schema of the plan. We should project away the extra
           // metadata columns if necessary.
           if (newNode.sameOutput(node)) {
             newNode
@@ -960,6 +949,38 @@ class Analyzer(override val catalogManager: CatalogManager)
           }
         }
     }
+
+    private def getMetadataAttributes(plan: LogicalPlan): Seq[Attribute] = {
+      plan.expressions.flatMap(_.collect {
+        case a: Attribute if a.isMetadataCol => a
+        case a: Attribute
+            if plan.children.exists(c => c.metadataOutput.exists(_.exprId == a.exprId)) =>
+          plan.children.collectFirst {
+            case c if c.metadataOutput.exists(_.exprId == a.exprId) =>
+              c.metadataOutput.find(_.exprId == a.exprId).get
+          }.get
+      })
+    }
+
+    private def hasMetadataCol(plan: LogicalPlan): Boolean = {
+      plan.expressions.exists(_.find {
+        case a: Attribute =>
+          // If an attribute is resolved before being labeled as metadata
+          // (i.e. from the originating Dataset), we check with expression ID
+          a.isMetadataCol ||
+            plan.children.exists(c => c.metadataOutput.exists(_.exprId == a.exprId))
+        case _ => false
+      }.isDefined)
+    }
+
+    private def addMetadataCol(plan: LogicalPlan): LogicalPlan = plan match {
+      case r: DataSourceV2Relation => r.withMetadataColumns()
+      case p: Project =>
+        p.copy(
+          projectList = p.metadataOutput ++ p.projectList,
+          child = addMetadataCol(p.child))
+      case _ => plan.withNewChildren(plan.children.map(addMetadataCol))
+    }
   }
 
   /**
@@ -1898,10 +1919,10 @@ class Analyzer(override val catalogManager: CatalogManager)
   }
 
   /**
-   * This method tries to resolve expressions and find missing attributes recursively. Specially,
-   * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
-   * attributes which are missed from child output. This method tries to find the missing
-   * attributes out and add into the projection.
+   * This method tries to resolve expressions and find missing attributes recursively.
+   * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes
+   * or resolved attributes which are missing from child output. This method tries to find the
+   * missing attributes and add them into the projection.
    */
   private def resolveExprsAndAddMissingAttrs(
       exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
@@ -3150,7 +3171,9 @@ class Analyzer(override val catalogManager: CatalogManager)
       joinType: JoinType,
       joinNames: Seq[String],
       condition: Option[Expression],
-      hint: JoinHint) = {
+      hint: JoinHint): LogicalPlan = {
+    import org.apache.spark.sql.catalyst.util._
+
     val leftKeys = joinNames.map { keyName =>
       left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
         throw QueryCompilationErrors.unresolvedUsingColForJoinError(keyName, left, "left")
@@ -3170,26 +3193,33 @@ class Analyzer(override val catalogManager: CatalogManager)
     val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
 
     // the output list looks like: join keys, columns from left, columns from right
-    val projectList = joinType match {
+    val (projectList, hiddenList) = joinType match {
       case LeftOuter =>
-        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+        (leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)), rightKeys)
       case LeftExistence(_) =>
-        leftKeys ++ lUniqueOutput
+        (leftKeys ++ lUniqueOutput, Seq.empty)
       case RightOuter =>
-        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+        (rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput, leftKeys)
       case FullOuter =>
         // in full outer join, joinCols should be non-null if there is.
         val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
-        joinedCols ++
+        (joinedCols ++
           lUniqueOutput.map(_.withNullability(true)) ++
-          rUniqueOutput.map(_.withNullability(true))
+          rUniqueOutput.map(_.withNullability(true)),
+          leftKeys ++ rightKeys)
       case _ : InnerLike =>
-        leftKeys ++ lUniqueOutput ++ rUniqueOutput
+        (leftKeys ++ lUniqueOutput ++ rUniqueOutput, rightKeys)
      case _ =>
         sys.error("Unsupported natural join type " + joinType)
     }
-    // use Project to trim unnecessary fields
-    Project(projectList, Join(left, right, joinType, newCondition, hint))
+    // use Project to hide duplicated common keys
+    // propagate hidden columns from nested USING/NATURAL JOINs
+    val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
+    project.setTagValue(
+      Project.hiddenOutputTag,
+      hiddenList.map(_.markAsSupportsQualifiedStar()) ++
+        project.child.metadataOutput.filter(_.supportsQualifiedStar))
+    project
   }
 
   /**
```
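As a hedged illustration of the new `hiddenList` plumbing: for a FULL OUTER natural join, the visible key column is the coalesced pair, yet both original keys stay resolvable as hidden columns (continuing the spark-shell sketch above):

```scala
// k resolves to Coalesce(nt1.k, nt2.k) in the visible output; nt1.k and nt2.k
// are carried in the Project's hidden output, so all three references resolve.
// Previously, an outer join exposed neither side's key (see the table above).
spark.sql("SELECT k, nt1.k, nt2.k FROM nt1 NATURAL FULL OUTER JOIN nt2").show()
```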

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 9 additions & 6 deletions

```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
-import org.apache.spark.sql.catalyst.util.quoteIfNeeded
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, Metadata, StructType}
@@ -340,11 +340,11 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
    * Returns true if the nameParts is a subset of the last elements of qualifier of the attribute.
    *
    * For example, the following should all return true:
-   * - `SELECT ns1.ns2.t.* FROM ns1.n2.t` where nameParts is Seq("ns1", "ns2", "t") and
+   * - `SELECT ns1.ns2.t.* FROM ns1.ns2.t` where nameParts is Seq("ns1", "ns2", "t") and
    *   qualifier is Seq("ns1", "ns2", "t").
-   * - `SELECT ns2.t.* FROM ns1.n2.t` where nameParts is Seq("ns2", "t") and
+   * - `SELECT ns2.t.* FROM ns1.ns2.t` where nameParts is Seq("ns2", "t") and
    *   qualifier is Seq("ns1", "ns2", "t").
-   * - `SELECT t.* FROM ns1.n2.t` where nameParts is Seq("t") and
+   * - `SELECT t.* FROM ns1.ns2.t` where nameParts is Seq("t") and
    *   qualifier is Seq("ns1", "ns2", "t").
    */
   private def matchedQualifier(
@@ -366,10 +366,13 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
   override def expand(
       input: LogicalPlan,
       resolver: Resolver): Seq[NamedExpression] = {
-    // If there is no table specified, use all input attributes.
+    // If there is no table specified, use all non-hidden input attributes.
     if (target.isEmpty) return input.output
 
-    val expandedAttributes = input.output.filter(matchedQualifier(_, target.get, resolver))
+    // If there is a table specified, use hidden input attributes as well
+    val hiddenOutput = input.metadataOutput.filter(_.supportsQualifiedStar)
+    val expandedAttributes = (hiddenOutput ++ input.output).filter(
+      matchedQualifier(_, target.get, resolver))
 
     if (expandedAttributes.nonEmpty) return expandedAttributes
```
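The effect of the `expand` change, sketched against the same views: an unqualified star keeps the deduplicated output, while a qualified star may now pick up hidden attributes matching the qualifier:

```scala
spark.sql("SELECT * FROM nt1 NATURAL JOIN nt2").printSchema()
// k, v1, v2: hidden columns are not exposed by a bare *

spark.sql("SELECT nt2.* FROM nt1 NATURAL JOIN nt2").printSchema()
// k, v2: the hidden nt2.k qualifies for nt2.* expansion
```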

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -145,6 +145,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
       self.markRuleAsIneffective(ruleId)
       self
     } else {
+      afterRule.copyTagsFrom(self)
       afterRule
     }
   }
```
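This one-liner matters because `hiddenOutputTag` lives on the node instance: when an analyzer rule returns a structurally new node, the tag must be carried over or the hidden columns vanish. A toy sketch of the tag mechanics, runnable from catalyst-side code (the tag name here is hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val exampleTag = TreeNodeTag[String]("example")  // hypothetical tag, for illustration

val before = Project(Seq(a), LocalRelation(a))
before.setTagValue(exampleTag, "kept")

// Case-class copy produces a fresh node with empty tags; copyTagsFrom restores them.
val after = before.copy(projectList = Seq(a))
after.copyTagsFrom(before)
assert(after.getTagValue(exampleTag).contains("kept"))
```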

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 10 additions & 3 deletions

```diff
@@ -25,17 +25,18 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{
   INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
 }
-import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.random.RandomSampler
 
 /**
- * When planning take() or collect() operations, this special node that is inserted at the top of
+ * When planning take() or collect() operations, this special node is inserted at the top of
  * the logical plan before invoking the query planner.
  *
  * Rules can pattern-match on this node in order to apply transformations that only take effect
@@ -69,7 +70,6 @@ object Subquery {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
-  override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = child.maxRows
 
   override lazy val resolved: Boolean = {
@@ -86,10 +86,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   override lazy val validConstraints: ExpressionSet =
     getAllValidConstraints(projectList)
 
+  override def metadataOutput: Seq[Attribute] =
+    getTagValue(Project.hiddenOutputTag).getOrElse(Nil)
+
   override protected def withNewChildInternal(newChild: LogicalPlan): Project =
     copy(child = newChild)
 }
 
+object Project {
+  val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
+}
+
 /**
  * Applies a [[Generator]] to a stream of input rows, combining the
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
```
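A small sketch of the new `Project` behavior: hidden output is empty unless the tag is set, and it never leaks into the visible schema (attribute setup is illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

val project = Project(Seq(a), LocalRelation(a, b))
assert(project.metadataOutput.isEmpty)    // no tag, so no hidden output

project.setTagValue(Project.hiddenOutputTag, Seq(b))
assert(project.metadataOutput == Seq(b))  // resolvable as hidden output
assert(project.output == Seq(a))          // visible schema is unchanged
```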

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala

Lines changed: 25 additions & 1 deletion

```diff
@@ -25,7 +25,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{NumericType, StringType}
+import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
@@ -201,4 +201,28 @@ package object util extends Logging {
   def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = {
     truncatedString(seq, "", sep, "", maxFields)
   }
+
+  val METADATA_COL_ATTR_KEY = "__metadata_col"
+
+  implicit class MetadataColumnHelper(attr: Attribute) {
+    /**
+     * If set, this metadata column is a candidate during qualified star expansions.
+     */
+    val SUPPORTS_QUALIFIED_STAR = "__supports_qualified_star"
+
+    def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
+      attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)
+
+    def supportsQualifiedStar: Boolean = attr.isMetadataCol &&
+      attr.metadata.contains(SUPPORTS_QUALIFIED_STAR) &&
+      attr.metadata.getBoolean(SUPPORTS_QUALIFIED_STAR)
+
+    def markAsSupportsQualifiedStar(): Attribute = attr.withMetadata(
+      new MetadataBuilder()
+        .withMetadata(attr.metadata)
+        .putBoolean(METADATA_COL_ATTR_KEY, true)
+        .putBoolean(SUPPORTS_QUALIFIED_STAR, true)
+        .build()
+    )
+  }
 }
```
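Usage sketch for the relocated helper (attribute name chosen for illustration):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util._  // brings MetadataColumnHelper into scope
import org.apache.spark.sql.types.StringType

val k = AttributeReference("k", StringType)()
assert(!k.isMetadataCol)

// Marking returns a new attribute whose metadata carries both flags.
val hidden = k.markAsSupportsQualifiedStar()
assert(hidden.isMetadataCol)
assert(hidden.supportsQualifiedStar)
```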

sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala

Lines changed: 2 additions & 8 deletions

```diff
@@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionSpec, UnresolvedPartitionSpec}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDelete, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability, TruncatableTable}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 object DataSourceV2Implicits {
-  private val METADATA_COL_ATTR_KEY = "__metadata_col"
-
   implicit class TableHelper(table: Table) {
     def asReadable: SupportsRead = {
       table match {
@@ -101,11 +100,6 @@ object DataSourceV2Implicits {
     def toAttributes: Seq[AttributeReference] = asStruct.toAttributes
   }
 
-  implicit class MetadataColumnHelper(attr: Attribute) {
-    def isMetadataCol: Boolean = attr.metadata.contains(METADATA_COL_ATTR_KEY) &&
-      attr.metadata.getBoolean(METADATA_COL_ATTR_KEY)
-  }
-
   implicit class OptionsHelper(options: Map[String, String]) {
     def asOptions: CaseInsensitiveStringMap = {
       new CaseInsensitiveStringMap(options.asJava)
```

sql/core/src/test/resources/sql-tests/inputs/natural-join.sql

Lines changed: 53 additions & 0 deletions

```diff
@@ -10,6 +10,19 @@ create temporary view nt2 as select * from values
   ("one", 5)
   as nt2(k, v2);
 
+create temporary view nt3 as select * from values
+  ("one", 4),
+  ("two", 5),
+  ("one", 6)
+  as nt3(k, v3);
+
+create temporary view nt4 as select * from values
+  ("one", 7),
+  ("two", 8),
+  ("one", 9)
+  as nt4(k, v4);
+
+SELECT * FROM nt1 natural join nt2;
 
 SELECT * FROM nt1 natural join nt2 where k = "one";
 
@@ -18,3 +31,43 @@ SELECT * FROM nt1 natural left join nt2 order by v1, v2;
 SELECT * FROM nt1 natural right join nt2 order by v1, v2;
 
 SELECT count(*) FROM nt1 natural full outer join nt2;
+
+SELECT k FROM nt1 natural join nt2;
+
+SELECT k FROM nt1 natural join nt2 where k = "one";
+
+SELECT nt1.* FROM nt1 natural join nt2;
+
+SELECT nt2.* FROM nt1 natural join nt2;
+
+SELECT sbq.* from (SELECT * FROM nt1 natural join nt2) sbq;
+
+SELECT sbq.k from (SELECT * FROM nt1 natural join nt2) sbq;
+
+SELECT nt1.*, nt2.* FROM nt1 natural join nt2;
+
+SELECT *, nt2.k FROM nt1 natural join nt2;
+
+SELECT nt1.k, nt2.k FROM nt1 natural join nt2;
+
+SELECT nt1.k, nt2.k FROM nt1 natural join nt2 where k = "one";
+
+SELECT * FROM (SELECT * FROM nt1 natural join nt2);
+
+SELECT * FROM (SELECT nt1.*, nt2.* FROM nt1 natural join nt2);
+
+SELECT * FROM (SELECT nt1.v1, nt2.k FROM nt1 natural join nt2);
+
+SELECT nt2.k FROM (SELECT * FROM nt1 natural join nt2);
+
+SELECT * FROM nt1 natural join nt2 natural join nt3;
+
+SELECT nt1.*, nt2.*, nt3.* FROM nt1 natural join nt2 natural join nt3;
+
+SELECT nt1.*, nt2.*, nt3.* FROM nt1 natural join nt2 join nt3 on nt2.k = nt3.k;
+
+SELECT * FROM nt1 natural join nt2 join nt3 on nt1.k = nt3.k;
+
+SELECT * FROM nt1 natural join nt2 join nt3 on nt2.k = nt3.k;
+
+SELECT nt1.*, nt2.*, nt3.*, nt4.* FROM nt1 natural join nt2 natural join nt3 natural join nt4;
```
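For the chained case near the end of the file, a hedged sketch of what the qualified stars now expand to (assuming the nt1 through nt3 views above are registered in the session):

```scala
// nt1.* -> k, v1; nt2.* -> k, v2; nt3.* -> k, v3: each side's copy of the
// common key resurfaces from the nested join's hidden output.
spark.sql(
  "SELECT nt1.*, nt2.*, nt3.* FROM nt1 NATURAL JOIN nt2 NATURAL JOIN nt3"
).printSchema()
```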
