Commit e12fd14

Address comments
1 parent 0621360 commit e12fd14

5 files changed: 43 additions (+), 50 deletions (-)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 10 additions & 3 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
-import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, AttributeSet, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -48,8 +48,15 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
       newAggregate
     }
 
-    case agg @ Aggregate(groupingExps, _, child) if agg.groupOnly && child.deterministic &&
-      child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
+    case agg @ Aggregate(groupingExps, _, child)
+      if agg.groupOnly && child.deterministic &&
+        child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
+      Project(agg.aggregateExpressions, child)
+
+    case agg @ Aggregate(groupingExps, aggregateExps, child)
+      if aggregateExps.forall(a => a.isInstanceOf[Alias] && a.children.forall(_.foldable)) &&
+        child.deterministic &&
+        child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
       Project(agg.aggregateExpressions, child)
   }
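
Note (not part of the commit): the second case added above fires when every aggregate expression is an Alias over foldable children and the child's distinct keys already cover the grouping expressions. A minimal sketch in the Catalyst test DSL, with hypothetical relation and column names, of a plan shape it rewrites:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

  // Hypothetical example, not taken from the commit.
  val testRelation = LocalRelation('a.int, 'b.int)
  // The inner group-by makes (a, b) a distinct key of its output.
  val child = testRelation.groupBy('a, 'b)('a, 'b)
  // Outer aggregate: grouping keys (a, b), output is only an alias over a foldable
  // literal, so the rule can replace the Aggregate with a Project over the same child.
  val before = child.groupBy('a, 'b)(TrueLiteral.as("flag"))
  val after  = child.select(TrueLiteral.as("flag"))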

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala

Lines changed: 7 additions & 10 deletions

@@ -80,16 +80,13 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
   }
 
   override def visitJoin(p: Join): Set[ExpressionSet] = {
-    p.joinType match {
-      case LeftSemiOrAnti(_) => p.left.distinctKeys
-      case Inner =>
-        p match {
-          case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _, _)
-            if p.left.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) &&
-              p.right.distinctKeys.exists(_.subsetOf(ExpressionSet(rightKeys))) =>
-            Set(ExpressionSet(leftKeys), ExpressionSet(rightKeys))
-          case _ => default(p)
-        }
+    p match {
+      case Join(_, _, LeftSemiOrAnti(_), _, _) =>
+        p.left.distinctKeys
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, _, _, _, _, _)
+        if p.left.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) &&
+          p.right.distinctKeys.exists(_.subsetOf(ExpressionSet(rightKeys))) =>
+        Set(ExpressionSet(leftKeys), ExpressionSet(rightKeys))
       case _ => default(p)
     }
   }
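
Note (not part of the commit): the restructured visitJoin keeps the previous semantics, only flattened into a single pattern match on the Join node. Left semi/anti joins still propagate the left side's distinct keys because they never duplicate left rows, and an inner equi-join propagates both sides' keys only when each side's join keys cover one of that side's distinct keys. A small sketch of the subset check used above, with hypothetical attributes:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions.ExpressionSet

  // Hypothetical attributes, bound once so both sets share the same expression ids.
  val a = 'a.int
  val b = 'b.int
  val distinctKey = ExpressionSet(Seq(a, b))   // a distinct key of one join side
  val joinKeys    = ExpressionSet(Seq(a, b))   // equi-join keys on that side
  // The key survives the inner join because the join keys cover it.
  assert(distinctKey.subsetOf(joinKeys))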

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.ExpressionSet
+import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED
 
 /**
  * A trait to add distinct attributes to [[LogicalPlan]]. For example:
@@ -28,6 +29,6 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionSet
  */
 trait LogicalPlanDistinctKeys { self: LogicalPlan =>
   lazy val distinctKeys: Set[ExpressionSet] = {
-    if (conf.propagateDistinctKeysEnabled) DistinctKeyVisitor.visit(self) else Set.empty
+    if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty
   }
 }
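
Note (not part of the commit): with the import added above, the trait reads the config entry directly through conf.getConf, so the one-line propagateDistinctKeysEnabled accessor removed from SQLConf.scala in the next file is no longer referenced anywhere. A hedged usage sketch, assuming only the entry name shown in this diff:

  import org.apache.spark.sql.internal.SQLConf
  import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED

  // Reading the flag the same way LogicalPlanDistinctKeys now does.
  val enabled: Boolean = SQLConf.get.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)
  // Disabling it (for example in a test) makes distinctKeys return Set.empty.
  SQLConf.get.setConf(PROPAGATE_DISTINCT_KEYS_ENABLED, false)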

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 2 deletions

@@ -3908,8 +3908,6 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
-  def propagateDistinctKeysEnabled: Boolean = getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)
-
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 24 additions & 34 deletions

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.IntegerType
 
@@ -230,7 +230,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
     }
   }
 
-  test("SPARK-36194: Remove aggregation from aggregation") {
+  test("SPARK-36194: Child distinct keys is the subset of required keys") {
     val originalQuery = relation
       .groupBy('a)('a, count('b).as("cnt"))
       .groupBy('a, 'cnt)('a, 'cnt)
@@ -243,48 +243,38 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-36194: Negative case: The grouping expressions not same") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
-        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .groupBy("x.a".attr)("x.a".attr)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
+  test("SPARK-36194: Child distinct keys are subsets and aggregateExpressions are foldable") {
+    val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .groupBy("x.a".attr, "x.b".attr)(TrueLiteral)
+      .analyze
+    val correctAnswer = x.groupBy('a, 'b)('a, 'b)
+      .join(y, LeftSemi, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .select(TrueLiteral)
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
   }
 
-  test("SPARK-36194: Negative case: The aggregate expressions not the sub aggregateExprs") {
+  test("SPARK-36194: Negative case: child distinct keys is not the subset of required keys") {
     Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      val originalQuery1 = x.groupBy('a, 'b)('a, 'b)
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .groupBy("x.a".attr, "x.b".attr)(TrueLiteral)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
-  }
+        .groupBy("x.a".attr)("x.a".attr)
+        .analyze
+      comparePlans(Optimize.execute(originalQuery1), originalQuery1)
 
-  test("SPARK-36194: Negative case: The aggregate expressions not same") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+      val originalQuery2 = x.groupBy('a, 'b)('a, 'b)
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
         .groupBy("x.a".attr)(count("x.b".attr))
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
+        .analyze
+      comparePlans(Optimize.execute(originalQuery2), originalQuery2)
     }
   }
 
-  test("SPARK-36194: Negative case: The aggregate expressions with Literal") {
-    Seq(LeftSemi, LeftAnti).foreach { joinType =>
-      val originalQuery = x.groupBy('a, 'b)('a, TrueLiteral)
-        .join(y, joinType, Some("x.a".attr === "y.a".attr))
-        .groupBy("x.a".attr)("x.a".attr, TrueLiteral)
-
-      val optimized = Optimize.execute(originalQuery.analyze)
-      comparePlans(optimized, originalQuery.analyze)
-    }
+  test("SPARK-36194: Negative case: child distinct keys is empty") {
+    val originalQuery = Distinct(x.groupBy('a, 'b)('a, TrueLiteral)).analyze
+    comparePlans(Optimize.execute(originalQuery), originalQuery)
   }
 
   test("SPARK-36194: Negative case: Remove aggregation from contains non-deterministic") {
