Skip to content

Commit 19f7d72

Browse files
committed
Address all comments
1 parent 33db6df commit 19f7d72

File tree

3 files changed

+12
-18
lines changed

3 files changed

+12
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,12 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
4848
newAggregate
4949
}
5050

51-
case agg @ Aggregate(groupingExps, _, child)
52-
if agg.groupOnly && child.deterministic &&
53-
child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
51+
case agg @ Aggregate(groupingExps, _, child)
52+
if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
5453
Project(agg.aggregateExpressions, child)
5554

5655
case agg @ Aggregate(groupingExps, aggregateExps, child)
5756
if aggregateExps.forall(a => a.isInstanceOf[Alias] && a.children.forall(_.foldable)) &&
58-
child.deterministic &&
5957
child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
6058
Project(agg.aggregateExpressions, child)
6159
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
3838
expressionSet.map { expression =>
3939
expression transform {
4040
case expr: Expression =>
41+
// TODO: Expand distinctKeys for redundant aliases on the same expression
4142
aliases
4243
.collectFirst { case a: Alias if a.child.semanticEquals(expr) => a.toAttribute }
4344
.getOrElse(expr)
@@ -60,7 +61,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
6061
override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output))
6162

6263
override def visitExcept(p: Except): Set[ExpressionSet] =
63-
if (!p.isAll && p.deterministic) Set(ExpressionSet(p.output)) else default(p)
64+
if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)
6465

6566
override def visitExpand(p: Expand): Set[ExpressionSet] = default(p)
6667

@@ -76,7 +77,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
7677
}
7778

7879
override def visitIntersect(p: Intersect): Set[ExpressionSet] = {
79-
if (!p.isAll && p.deterministic) Set(ExpressionSet(p.output)) else default(p)
80+
if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)
8081
}
8182

8283
override def visitJoin(p: Join): Set[ExpressionSet] = {
@@ -108,7 +109,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
108109

109110
override def visitProject(p: Project): Set[ExpressionSet] = {
110111
if (p.child.distinctKeys.nonEmpty) {
111-
projectDistinctKeys(p.child.distinctKeys.map(ExpressionSet(_)), p.projectList)
112+
projectDistinctKeys(p.child.distinctKeys, p.projectList)
112113
} else {
113114
default(p)
114115
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,17 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
168168
}
169169
}
170170

171-
test("Keep non-redundant aggregate - upper references non-deterministic non-grouping") {
171+
test("Remove non-redundant aggregate - upper references non-deterministic non-grouping") {
172172
val query = relation
173173
.groupBy('a)('a, ('a + rand(0)) as 'c)
174174
.groupBy('a, 'c)('a, 'c)
175175
.analyze
176+
val expected = relation
177+
.groupBy('a)('a, ('a + rand(0)) as 'c)
178+
.select('a, 'c)
179+
.analyze
176180
val optimized = Optimize.execute(query)
177-
comparePlans(optimized, query)
181+
comparePlans(optimized, expected)
178182
}
179183

180184
test("SPARK-36194: Remove aggregation from left semi/anti join if aggregation the same") {
@@ -283,13 +287,4 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
283287
val originalQuery = Distinct(x.groupBy('a, 'b)('a, TrueLiteral)).analyze
284288
comparePlans(Optimize.execute(originalQuery), originalQuery)
285289
}
286-
287-
test("SPARK-36194: Negative case: Remove aggregation from contains non-deterministic") {
288-
val query = relation
289-
.groupBy('a)('a, (count('b) + rand(0)).as("cnt"))
290-
.groupBy('a, 'cnt)('a, 'cnt)
291-
.analyze
292-
val optimized = Optimize.execute(query)
293-
comparePlans(optimized, query)
294-
}
295290
}

0 commit comments

Comments
 (0)