
Commit c31f727

wangyum authored; GitHub Enterprise committed
[CARMEL-5845] Add a logical plan visitor to propagate the distinct attributes (#884)
* Remove the aggregation from a left semi/anti join if the same aggregation has already been done on the left side
* Add more tests
* grouping -> groupingExps
* Add DistinctAttributesVisitor
* Fix test name
* Improve DistinctAttributesVisitor
* Fix test
* DistinctKeyVisitor
* Address comments
* Fix Scala 2.13
* Address comments
* Address all comments
* Address all comments
* fix
* fix test
* [SPARK-38489][SQL] Aggregate.groupOnly support foldable expressions

  ### What changes were proposed in this pull request?
  This PR makes `Aggregate.groupOnly` support foldable expressions.

  ### Why are the changes needed?
  Improve query performance.

  ### Does this PR introduce _any_ user-facing change?
  No.

  ### How was this patch tested?
  Unit test.

  Closes #35795 from wangyum/SPARK-38489.

  Authored-by: Yuming Wang <yumwang@ebay.com>
  Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  (cherry picked from commit bcf7849)
* Remove
* fix
* Update DistinctKeyVisitor.scala
1 parent 8adae4c · commit c31f727

File tree: 11 files changed (+565, −7 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
(1 addition, 0 deletions)

@@ -108,6 +108,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery,
       EliminateSerialization,
       RemoveRedundantAliases,
+      RemoveRedundantAggregates,
       UnwrapCastInBinaryComparison,
       RemoveNoopOperators,
       SimplifyExtractValueOps,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
(new file, 34 additions)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, ExpressionSet}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Remove redundant aggregates from a query plan. A redundant aggregate is an aggregate whose
 * only goal is to keep distinct values, while its parent aggregate would ignore duplicate values.
 */
object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case agg @ Aggregate(groupingExps, _, child)
        if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
      Project(agg.aggregateExpressions, child)
  }
}
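
For orientation, a minimal sketch of a query shape this rule simplifies. It is an illustration, not part of the commit; the local session and the table t(a, b) are made up for the example.

import org.apache.spark.sql.SparkSession

// Hypothetical setup for a self-contained demo.
val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._
Seq((1, "x"), (1, "x"), (2, "y")).toDF("a", "b").createOrReplaceTempView("t")

// The inner DISTINCT is an Aggregate that already guarantees {a, b} as a
// distinct key, so the outer group-only Aggregate adds nothing: the rule
// rewrites it into a Project over the child.
spark.sql("SELECT a, b FROM (SELECT DISTINCT a, b FROM t) GROUP BY a, b").explain(true)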

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
(4 additions, 4 deletions)

@@ -169,16 +169,16 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))

     case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _))
-        if a.groupOnly && a.references.subsetOf(AttributeSet(left.output)) =>
+        if a.groupOnly && a.references.subsetOf(left.outputSet) =>
       a.copy(child = left)
     case a @ Aggregate(_, _, Join(_, right, RightOuter, _, _))
-        if a.groupOnly && a.references.subsetOf(AttributeSet(right.output)) =>
+        if a.groupOnly && a.references.subsetOf(right.outputSet) =>
       a.copy(child = right)
     case a @ Aggregate(_, _, p @ Project(_, Join(left, _, LeftOuter, _, _)))
-        if a.groupOnly && a.references.subsetOf(AttributeSet(left.output)) =>
+        if a.groupOnly && p.references.subsetOf(left.outputSet) =>
       a.copy(child = p.copy(child = left))
     case a @ Aggregate(_, _, p @ Project(_, Join(_, right, RightOuter, _, _)))
-        if a.groupOnly && a.references.subsetOf(AttributeSet(right.output)) =>
+        if a.groupOnly && p.references.subsetOf(right.outputSet) =>
       a.copy(child = p.copy(child = right))
   }
 }
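
The switch from a.references to p.references in the two Project cases is what lets the rule fire when the Project introduces aliases, which the SPARK-37292 test below exercises. A sketch of the reasoning; the plan shape is illustrative, not taken from the suite:

// Aggregate [newAlias], [newAlias]      -- a.references = {newAlias}
// +- Project [x.b AS newAlias]          -- p.references = {x.b}
//    +- Join LeftOuter, (x.a = y.a)
//
// The Aggregate resolves against the Project's output, so once an alias is
// involved a.references contains `newAlias`, which never appears in
// left.outputSet, and the old guard could never match. Checking p.references
// asks the right question (does the Project read only the left side?) and
// lets the unused side of the outer join be dropped.
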
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala
(new file, 130 additions)

/* ... Apache License, Version 2.0 header, identical to the one above ... */

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionSet, NamedExpression}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemiOrAnti, RightOuter}

/**
 * A visitor pattern for traversing a [[LogicalPlan]] tree and propagating the distinct attributes.
 */
object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {

  private def projectDistinctKeys(
      keys: Set[ExpressionSet], projectList: Seq[NamedExpression]): Set[ExpressionSet] = {
    val outputSet = ExpressionSet(projectList.map(_.toAttribute))
    val aliases = projectList.filter(_.isInstanceOf[Alias])
    if (aliases.isEmpty) {
      keys.filter(_.subsetOf(outputSet)).filter(_.nonEmpty)
    } else {
      val expressions = keys.flatMap(_.toSet)
      projectList.filter {
        case a: Alias => expressions.exists(_.semanticEquals(a.child))
        case ne => expressions.exists(_.semanticEquals(ne))
      }.toSet.subsets(keys.map(_.size).min).take(50).filter { s =>
        val references = s.map {
          case a: Alias => a.child
          case ne => ne
        }
        keys.exists(_.equals(ExpressionSet(references)))
      }.map(s => ExpressionSet(s.map(_.toAttribute))).filter(_.nonEmpty).toSet
    }
  }

  override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]

  override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
    val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a
    projectDistinctKeys(Set(groupingExps), p.aggregateExpressions)
  }

  override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output))

  override def visitExcept(p: Except): Set[ExpressionSet] =
    if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)

  override def visitExpand(p: Expand): Set[ExpressionSet] = default(p)

  override def visitFilter(p: Filter): Set[ExpressionSet] = p.child.distinctKeys

  override def visitGenerate(p: Generate): Set[ExpressionSet] = default(p)

  override def visitGlobalLimit(p: GlobalLimit): Set[ExpressionSet] = {
    p.maxRows match {
      case Some(value) if value <= 1 => Set(ExpressionSet(p.output))
      case _ => p.child.distinctKeys
    }
  }

  override def visitIntersect(p: Intersect): Set[ExpressionSet] = {
    if (!p.isAll) Set(ExpressionSet(p.output)) else default(p)
  }

  override def visitJoin(p: Join): Set[ExpressionSet] = {
    p match {
      case Join(_, _, LeftSemiOrAnti(_), _, _) =>
        p.left.distinctKeys
      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, left, right, _)
          if left.distinctKeys.nonEmpty || right.distinctKeys.nonEmpty =>
        val rightJoinKeySet = ExpressionSet(rightKeys)
        val leftJoinKeySet = ExpressionSet(leftKeys)
        joinType match {
          case Inner if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) &&
              right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
            left.distinctKeys ++ right.distinctKeys
          case Inner | LeftOuter if right.distinctKeys.exists(_.subsetOf(rightJoinKeySet)) =>
            p.left.distinctKeys
          case Inner | RightOuter if left.distinctKeys.exists(_.subsetOf(leftJoinKeySet)) =>
            p.right.distinctKeys
          case _ =>
            default(p)
        }
      case _ => default(p)
    }
  }

  override def visitLocalLimit(p: LocalLimit): Set[ExpressionSet] = p.child.distinctKeys

  override def visitPivot(p: Pivot): Set[ExpressionSet] = default(p)

  override def visitProject(p: Project): Set[ExpressionSet] = {
    if (p.child.distinctKeys.nonEmpty) {
      projectDistinctKeys(p.child.distinctKeys, p.projectList)
    } else {
      default(p)
    }
  }

  override def visitRepartition(p: Repartition): Set[ExpressionSet] = p.child.distinctKeys

  override def visitRepartitionByExpr(p: RepartitionByExpression): Set[ExpressionSet] =
    p.child.distinctKeys

  override def visitSample(p: Sample): Set[ExpressionSet] = {
    if (!p.withReplacement) p.child.distinctKeys else default(p)
  }

  override def visitScriptTransform(p: ScriptTransformation): Set[ExpressionSet] = default(p)

  override def visitUnion(p: Union): Set[ExpressionSet] = default(p)

  override def visitWindow(p: Window): Set[ExpressionSet] = p.child.distinctKeys

  override def visitSort(p: Sort): Set[ExpressionSet] = p.child.distinctKeys
}
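
To see the propagation end to end, a small sketch in the style of the Catalyst DSL used by the test suite below; the relation and alias names are made up for the example:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val r = LocalRelation('a.int, 'b.int, 'c.int)

// visitAggregate reports the grouping keys {a, b} as a distinct key;
// visitProject then maps that key through the aliases via projectDistinctKeys
// (x -> a, y -> b), so the top of the plan reports {x, y}.
val plan = r.groupBy('a, 'b)('a, 'b).select('a.attr.as("x"), 'b.attr.as("y")).analyze
println(plan.distinctKeys) // expected: a single ExpressionSet containing x and y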

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
(1 addition, 0 deletions)

@@ -30,6 +30,7 @@ abstract class LogicalPlan
   extends QueryPlan[LogicalPlan]
   with AnalysisHelper
   with LogicalPlanStats
+  with LogicalPlanDistinctKeys
   with QueryPlanConstraints
   with Logging {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanDistinctKeys.scala
(new file, 34 additions)

/* ... Apache License, Version 2.0 header, identical to the one above ... */

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.ExpressionSet
import org.apache.spark.sql.internal.SQLConf.PROPAGATE_DISTINCT_KEYS_ENABLED

/**
 * A trait to add distinct attributes to [[LogicalPlan]]. For example:
 * {{{
 *   SELECT a, b, SUM(c) FROM Tab1 GROUP BY a, b
 *   // returns a, b
 * }}}
 */
trait LogicalPlanDistinctKeys { self: LogicalPlan =>
  lazy val distinctKeys: Set[ExpressionSet] = {
    if (conf.getConf(PROPAGATE_DISTINCT_KEYS_ENABLED)) DistinctKeyVisitor.visit(self) else Set.empty
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
(6 additions, 1 deletion)

@@ -648,7 +648,12 @@ case class Aggregate(

   // Whether this Aggregate operator is group only. For example: SELECT a, a FROM t GROUP BY a
   private[sql] def groupOnly: Boolean = {
-    aggregateExpressions.forall(a => groupingExpressions.exists(g => a.semanticEquals(g)))
+    // aggregateExpressions can be empty through Dataset.agg,
+    // so we should also check that groupingExpressions is non-empty
+    groupingExpressions.nonEmpty && aggregateExpressions.map {
+      case Alias(child, _) => child
+      case e => e
+    }.forall(a => a.foldable || groupingExpressions.exists(g => a.semanticEquals(g)))
   }
 }
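
For intuition, a few hypothetical cases under the new definition (the SQL shapes are illustrative, not from the patch):

// SELECT a, a FROM t GROUP BY a          -- group only: every output is a grouping expression
// SELECT a, 1, true AS flag FROM t GROUP BY a
//                                        -- now also group only: 1 and true are foldable
// SELECT a, sum(b) FROM t GROUP BY a     -- not group only: sum(b) is a genuine aggregate
// ds.agg(sum($"b"))                      -- not group only: groupingExpressions is empty
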
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
(9 additions, 0 deletions)

@@ -731,6 +731,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

+  val PROPAGATE_DISTINCT_KEYS_ENABLED =
+    buildConf("spark.sql.optimizer.propagateDistinctKeys.enabled")
+      .internal()
+      .doc("When true, the query optimizer will propagate a set of distinct attributes from " +
+        "the current node and use it to optimize the query.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
(38 additions, 2 deletions)

@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -122,7 +123,42 @@ class AggregateOptimizeSuite extends AnalysisTest {
     Optimize.execute(
       x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
         .groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze),
-      x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
-        .groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze)
+      x.groupBy("x.a".attr)("x.a".attr, Literal(1)).analyze)
+  }
+
+  test("SPARK-37292: Removes outer join if it only has DISTINCT on streamed side with alias") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    comparePlans(
+      Optimize.execute(
+        Distinct(x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
+          .select("x.b".attr.as("newAlias"))).analyze),
+      x.select("x.b".attr.as("newAlias")).groupBy("newAlias".attr)("newAlias".attr).analyze)
+
+    comparePlans(
+      Optimize.execute(
+        Distinct(x.join(y, RightOuter, Some("x.a".attr === "y.a".attr))
+          .select("y.b".attr.as("newAlias"))).analyze),
+      y.select("y.b".attr.as("newAlias")).groupBy("newAlias".attr)("newAlias".attr).analyze)
+
+    comparePlans(
+      Optimize.execute(
+        Distinct(x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
+          .select("x.b".attr.as("newAlias1"), "x.b".attr.as("newAlias2"))).analyze),
+      x.select("x.b".attr.as("newAlias1"), "x.b".attr.as("newAlias2"))
+        .groupBy("newAlias1".attr, "newAlias2".attr)("newAlias1".attr, "newAlias2".attr).analyze)
+  }
+
+  test("SPARK-38489: Aggregate.groupOnly support foldable expressions") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    comparePlans(
+      Optimize.execute(
+        Distinct(x.join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
+          .select("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias")))
+          .analyze),
+      x.select("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias"))
+        .groupBy("x.b".attr)("x.b".attr, TrueLiteral, FalseLiteral.as("newAlias"))
+        .analyze)
   }
 }
