Skip to content

Commit 33db6df

Browse files
committed
Address all comments
1 parent e12fd14 commit 33db6df

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionSet, NamedExpression}
2121
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
22-
import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemiOrAnti}
22+
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemiOrAnti, RightOuter}
2323

2424
/**
2525
* A visitor pattern for traversing a [[LogicalPlan]] tree and propagate the distinct attributes.
@@ -83,10 +83,21 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
8383
p match {
8484
case Join(_, _, LeftSemiOrAnti(_), _, _) =>
8585
p.left.distinctKeys
86-
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, _, _, _, _, _)
87-
if p.left.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) &&
88-
p.right.distinctKeys.exists(_.subsetOf(ExpressionSet(rightKeys))) =>
89-
Set(ExpressionSet(leftKeys), ExpressionSet(rightKeys))
86+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, _)
87+
if left.distinctKeys.nonEmpty || right.distinctKeys.nonEmpty =>
88+
val rightJoinKeySet = ExpressionSet(rightKeys)
89+
val leftJoinKeySet = ExpressionSet(leftKeys)
90+
joinType match {
91+
case Inner if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) &&
92+
right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
93+
left.distinctKeys ++ right.distinctKeys
94+
case Inner | LeftOuter if right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
95+
p.left.distinctKeys
96+
case Inner | RightOuter if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) =>
97+
p.right.distinctKeys
98+
case _ =>
99+
default(p)
100+
}
90101
case _ => default(p)
91102
}
92103
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
150150
comparePlans(optimized, expected)
151151
}
152152

153+
test("Remove redundant aggregate - upper has contains foldable expressions") {
154+
val originalQuery = x.groupBy('a, 'b)('a, 'b).groupBy('a)('a, TrueLiteral).analyze
155+
val correctAnswer = x.groupBy('a)('a, TrueLiteral).analyze
156+
val optimized = Optimize.execute(originalQuery)
157+
comparePlans(optimized, correctAnswer)
158+
}
159+
153160
test("Keep non-redundant aggregate - upper references agg expression") {
154161
for (agg <- aggregates('b)) {
155162
val query = relation

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,29 @@ class DistinctKeyVisitorSuite extends PlanTest {
104104
Distinct(t1).join(Distinct(t2), Inner, Some('a === 'x && 'b === 'y && 'c === 'z)),
105105
Set(ExpressionSet(Seq(a, b, c)), ExpressionSet(Seq(x, y, z))))
106106

107-
checkDistinctAttributes(t1.join(t2, LeftSemi, Some('a === 'x)),
108-
Set.empty)
109107
checkDistinctAttributes(
110-
Distinct(t1).join(Distinct(t2), Inner, Some('a === 'x && 'b === 'y)),
111-
Set.empty)
108+
Distinct(t1).join(Distinct(t2), LeftOuter, Some('a === 'x && 'b === 'y && 'c === 'z)),
109+
Set(ExpressionSet(Seq(a, b, c))))
110+
112111
checkDistinctAttributes(
113-
Distinct(t1).join(Distinct(t2), Inner, Some('a === 'x && 'b === 'y && 'c % 5 === 'z % 5)),
114-
Set.empty)
115-
Seq(LeftOuter, Cross, RightOuter).foreach { joinType =>
112+
Distinct(t1).join(Distinct(t2), RightOuter, Some('a === 'x && 'b === 'y && 'c === 'z)),
113+
Set(ExpressionSet(Seq(x, y, z))))
114+
115+
Seq(Inner, Cross, LeftOuter, RightOuter).foreach { joinType =>
116+
checkDistinctAttributes(t1.join(t2, joinType, Some('a === 'x)),
117+
Set.empty)
118+
checkDistinctAttributes(
119+
Distinct(t1).join(Distinct(t2), joinType, Some('a === 'x && 'b === 'y)),
120+
Set.empty)
116121
checkDistinctAttributes(
117-
Distinct(t1).join(Distinct(t2), joinType, Some('a === 'x && 'b === 'y && 'c === 'z)),
122+
Distinct(t1).join(Distinct(t2), joinType,
123+
Some('a === 'x && 'b === 'y && 'c % 5 === 'z % 5)),
118124
Set.empty)
119125
}
126+
127+
checkDistinctAttributes(
128+
Distinct(t1).join(Distinct(t2), Cross, Some('a === 'x && 'b === 'y && 'c === 'z)),
129+
Set.empty)
120130
}
121131

122132
test("Project's distinct attributes") {

0 commit comments

Comments
 (0)