Skip to content

[SPARK-13427][SQL] Support USING clause in JOIN. #11297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,17 @@ fromClause
joinSource
@init { gParent.pushMsg("join source", state); }
@after { gParent.popMsg(state); }
: fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )*
: fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
| uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
;

joinCond
@init { gParent.pushMsg("join expression list", state); }
@after { gParent.popMsg(state); }
: KW_ON! expression
| KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks pretty good.


uniqueJoinSource
@init { gParent.pushMsg("unique join source", state); }
@after { gParent.popMsg(state); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ TOK_SETCONFIG;
TOK_DFS;
TOK_ADDFILE;
TOK_ADDJAR;
TOK_USING;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Analyzer(
ResolveSubquery ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalJoin ::
ResolveNaturalAndUsingJoin ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
Expand Down Expand Up @@ -1329,48 +1329,69 @@ class Analyzer(
}

/**
* Removes natural joins by calculating output columns based on output from two sides,
* Then apply a Project on a normal Join to eliminate natural join.
* Removes natural or using joins by calculating output columns based on output from two sides,
* Then apply a Project on a normal Join to eliminate natural or using join.
*/
object ResolveNaturalJoin extends Rule[LogicalPlan] {
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
if left.resolved && right.resolved && j.duplicateResolved =>
// Resolve the column names referenced in using clause from both the legs of join.
val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver))
val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver))
if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) {
val joinNames = lCols.map(exp => exp.name)
commonNaturalJoinProcessing(left, right, joinType, joinNames, None)
} else {
j
}
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)

// Add joinPairs to joinConditions
val newCondition = (condition ++ joinPairs.map {
case (l, r) => EqualTo(l, r)
}).reduceOption(And)

// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

// the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in full outer join, joinCols should be non-null if there is.
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case Inner =>
rightKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
}
}

private def commonNaturalJoinProcessing(
left: LogicalPlan,
right: LogicalPlan,
joinType: JoinType,
joinNames: Seq[String],
condition: Option[Expression]) = {
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)

val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)

// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

// the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
case LeftSemi =>
leftKeys ++ lUniqueOutput
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in full outer join, joinCols should be non-null if there is.
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case Inner =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
}


}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.UsingJoin
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -109,6 +110,12 @@ trait CheckAnalysis {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")

case j @ Join(_, _, UsingJoin(_, cols), _) =>
val from = operator.inputSet.map(_.name).mkString(", ")
failAnalysis(
s"using columns [${cols.mkString(",")}] " +
s"can not be resolved given input columns: [$from] ")

case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.sql}' " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
}

// push down the join filter into sub query scanning if applicable
Expand Down Expand Up @@ -1171,6 +1172,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
case UsingJoin(_, _) => sys.error("Untransformed Using join node")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'm not really sure what the point of these extra checks is. Is it only to remove a warning? All kinds of things break in the optimizer if the plan is unresolved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JoinType is sealed, so we need to put something in this pattern match

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,30 +419,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Unsupported join operation: $other")
}

val joinType = joinToken match {
case "TOK_JOIN" => Inner
case "TOK_CROSSJOIN" => Inner
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
case "TOK_NATURALJOIN" => NaturalJoin(Inner)
case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}
val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)

Join(nodeToRelation(relation1),
nodeToRelation(relation2),
joinType,
other.headOption.map(nodeToExpr))

joinCondition)
case _ =>
noParseRule("Relation", node)
}
}

protected def getJoinInfo(
joinToken: String,
joinConditionToken: Seq[ASTNode],
node: ASTNode): (JoinType, Option[Expression]) = {
val joinType = joinToken match {
case "TOK_JOIN" => Inner
case "TOK_CROSSJOIN" => Inner
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
case "TOK_NATURALJOIN" => NaturalJoin(Inner)
case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}

joinConditionToken match {
case Token("TOK_USING", columnList :: Nil) :: Nil =>
val colNames = columnList.children.collect {
case Token(name, Nil) => UnresolvedAttribute(name)
}
(UsingJoin(joinType, colNames), None)
/* Join expression specified using ON clause */
case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
}
}

protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
SortOrder(nodeToExpr(sortExpr), Ascending)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

object JoinType {
def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
case "inner" => Inner
Expand Down Expand Up @@ -66,3 +68,9 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
"Unsupported natural join type " + tpe)
override def sql: String = "NATURAL " + tpe.sql
}

case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
"Unsupported using join type " + tpe)
override def sql: String = "USING " + tpe.sql
}
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,11 @@ case class Join(
condition.forall(_.dataType == BooleanType)
}

// if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need
// to eliminate natural before we mark it resolved.
// if not a natural join, use `resolvedExceptNatural`. if it is a natural join or
// using join, we still need to eliminate natural or using before we mark it resolved.
override lazy val resolved: Boolean = joinType match {
case NaturalJoin(_) => false
case UsingJoin(_, _) => false
case _ => resolvedExceptNatural
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -35,56 +36,81 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
lazy val r3 = LocalRelation(aNotNull, bNotNull)
lazy val r4 = LocalRelation(cNotNull, bNotNull)

test("natural inner join") {
val plan = r1.join(r2, NaturalJoin(Inner), None)
test("natural/using inner join") {
val naturalPlan = r1.join(r2, NaturalJoin(Inner), None)
val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural left join") {
val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
test("natural/using left join") {
val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None)
val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural right join") {
val plan = r1.join(r2, NaturalJoin(RightOuter), None)
test("natural/using right join") {
val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None)
val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural full outer join") {
val plan = r1.join(r2, NaturalJoin(FullOuter), None)
test("natural/using full outer join") {
val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None)
val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None)
val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
Alias(Coalesce(Seq(a, a)), "a")(), b, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural inner join with no nullability") {
val plan = r3.join(r4, NaturalJoin(Inner), None)
test("natural/using inner join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(Inner), None)
val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, cNotNull)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural left join with no nullability") {
val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
test("natural/using left join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None)
val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural right join with no nullability") {
val plan = r3.join(r4, NaturalJoin(RightOuter), None)
test("natural/using right join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None)
val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, a, cNotNull)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("natural full outer join with no nullability") {
val plan = r3.join(r4, NaturalJoin(FullOuter), None)
test("natural/using full outer join with no nullability") {
val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
checkAnalysis(plan, expected)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}

test("using unresolved attribute") {
val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None)
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(usingPlan)
}
assert(error.message.contains(
"using columns ['d] can not be resolved given input columns: [b, a, c]"))
}
}
Loading