
Commit 4847f73

imback82 authored and maropu committed
[SPARK-30298][SQL] Respect aliases in output partitioning of projects and aggregates
### What changes were proposed in this pull request?

Currently, in the following scenario, the bucket join is not utilized:

```scala
val df = (0 until 20).map(i => (i, i)).toDF("i", "j").as("df")
df.write.format("parquet").bucketBy(8, "i").saveAsTable("t")
sql("CREATE VIEW v AS SELECT * FROM t")
sql("SELECT * FROM t a JOIN v b ON a.i = b.i").explain
```

```
== Physical Plan ==
*(4) SortMergeJoin [i#13], [i#15], Inner
:- *(1) Sort [i#13 ASC NULLS FIRST], false, 0
:  +- *(1) Project [i#13, j#14]
:     +- *(1) Filter isnotnull(i#13)
:        +- *(1) ColumnarToRow
:           +- FileScan parquet default.t[i#13,j#14] Batched: true, DataFilters: [isnotnull(i#13)], Format: Parquet, Location: InMemoryFileIndex[file:..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int,j:int>, SelectedBucketsCount: 8 out of 8
+- *(3) Sort [i#15 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(i#15, 8), true, [id=#64]   <----- Exchange node introduced
      +- *(2) Project [i#13 AS i#15, j#14 AS j#16]
         +- *(2) Filter isnotnull(i#13)
            +- *(2) ColumnarToRow
               +- FileScan parquet default.t[i#13,j#14] Batched: true, DataFilters: [isnotnull(i#13)], Format: Parquet, Location: InMemoryFileIndex[file:..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int,j:int>, SelectedBucketsCount: 8 out of 8
```

Notice that an `Exchange` node is present. This is because `Project` introduces aliases, and aliases are not considered when the child's `outputPartitioning` is matched against the join's `requiredChildDistribution` in `EnsureRequirements`, so the bucket join opportunity is missed. This PR makes the output partitioning of projects and aggregates alias-aware so that this scenario is supported.

### Why are the changes needed?

This allows the bucket join to be utilized in the above example.

### Does this PR introduce any user-facing change?

Yes. With the fix, the `explain` output is as follows:

```
== Physical Plan ==
*(3) SortMergeJoin [i#13], [i#15], Inner
:- *(1) Sort [i#13 ASC NULLS FIRST], false, 0
:  +- *(1) Project [i#13, j#14]
:     +- *(1) Filter isnotnull(i#13)
:        +- *(1) ColumnarToRow
:           +- FileScan parquet default.t[i#13,j#14] Batched: true, DataFilters: [isnotnull(i#13)], Format: Parquet, Location: InMemoryFileIndex[file:.., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int,j:int>, SelectedBucketsCount: 8 out of 8
+- *(2) Sort [i#15 ASC NULLS FIRST], false, 0
   +- *(2) Project [i#13 AS i#15, j#14 AS j#16]
      +- *(2) Filter isnotnull(i#13)
         +- *(2) ColumnarToRow
            +- FileScan parquet default.t[i#13,j#14] Batched: true, DataFilters: [isnotnull(i#13)], Format: Parquet, Location: InMemoryFileIndex[file:.., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int,j:int>, SelectedBucketsCount: 8 out of 8
```

Note that the `Exchange` node is no longer present.

### How was this patch tested?

New unit tests were added in `PlannerSuite` and `BucketedReadSuite`.

Closes #26943 from imback82/bucket_alias.

Authored-by: Terry Kim <yuminkim@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
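A quick way to check the behavior outside the test suites is to count `ShuffleExchangeExec` nodes in the executed plan, which is essentially what the new `BucketedReadSuite` test in this commit does. The following is a minimal sketch for a spark-shell session (where `spark` and `sql` are predefined); the table and view names simply mirror the scenario above:

```scala
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Disable broadcast joins so the planner has to consider a sort-merge bucket join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Recreate the scenario from the description: a bucketed table and a view over it.
(0 until 20).map(i => (i, i)).toDF("i", "j")
  .write.format("parquet").bucketBy(8, "i").saveAsTable("t")
sql("CREATE VIEW v AS SELECT * FROM t")

// With alias-aware output partitioning, the bucket join needs no shuffle,
// so no ShuffleExchangeExec should appear in the executed plan.
val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
val exchanges = plan.collect { case e: ShuffleExchangeExec => e }
assert(exchanges.isEmpty, s"unexpected shuffles: $exchanges")
```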
1 parent 3228d72 · commit 4847f73

File tree

7 files changed: +166 -10 lines changed

Lines changed: 55 additions & 0 deletions
```diff
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+
+/**
+ * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
+ * that satisfies output distribution requirements.
+ */
+trait AliasAwareOutputPartitioning extends UnaryExecNode {
+  protected def outputExpressions: Seq[NamedExpression]
+
+  final override def outputPartitioning: Partitioning = {
+    if (hasAlias) {
+      child.outputPartitioning match {
+        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
+        case other => other
+      }
+    } else {
+      child.outputPartitioning
+    }
+  }
+
+  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
+      case other => other
+    }
+  }
+
+  private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
+    outputExpressions.collectFirst {
+      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
+        a.toAttribute
+    }
+  }
+}
```
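To make the rewrite concrete, here is a small standalone sketch (not part of the patch) that mirrors the trait's `replaceAliases`/`replaceAlias` logic on hand-built Catalyst objects; the attribute names and the partition count are illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// The child operator outputs attributes i and j, hash-partitioned on i.
val i = AttributeReference("i", IntegerType)()
val j = AttributeReference("j", IntegerType)()
val childPartitioning = HashPartitioning(Seq(i), 8)

// The project emits `i AS i2, j AS j2`, i.e. its outputExpressions are aliases.
val outputExpressions = Seq(Alias(i, "i2")(), Alias(j, "j2")())

// Mirror replaceAliases: swap each partitioning attribute for the attribute of the
// alias whose child is semantically equal to it, if such an alias exists.
val rewritten = childPartitioning.expressions.map {
  case a: AttributeReference =>
    outputExpressions.collectFirst {
      case al @ Alias(child: AttributeReference, _) if child.semanticEquals(a) => al.toAttribute
    }.getOrElse(a)
  case other => other
}

// The project can now report HashPartitioning on i2, so a join keyed on i2
// sees a compatible partitioning and EnsureRequirements inserts no Exchange.
val projectPartitioning = childPartitioning.copy(expressions = rewritten)
```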

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -53,7 +53,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with BlockingOperatorWithCodegen {
+  extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -75,7 +75,7 @@ case class HashAggregateExec(

   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
```

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -67,7 +67,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -97,7 +97,7 @@ case class ObjectHashAggregateExec(
     }
   }

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
```

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
@@ -38,7 +38,7 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with AliasAwareOutputPartitioning {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,7 +66,7 @@ case class SortAggregateExec(
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

   override def outputOrdering: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
```

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 2 additions & 3 deletions
```diff
@@ -38,7 +38,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

 /** Physical plan for Project. */
 case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

@@ -81,7 +81,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override protected def outputExpressions: Seq[NamedExpression] = projectList

   override def verboseStringWithOperatorId(): String = {
     s"""
@@ -92,7 +92,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 }

-
 /** Physical plan for Filter. */
 case class FilterExec(condition: Expression, child: SparkPlan)
   extends UnaryExecNode with CodegenSupport with PredicateHelper {
```

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 88 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -937,6 +938,93 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("aliases in the project should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("df1", "df2") {
+        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
+        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
+        val planned = sql(
+          """
+            |SELECT * FROM
+            |  (SELECT key AS k from df1) t1
+            |INNER JOIN
+            |  (SELECT key AS k from df2) t2
+            |ON t1.k = t2.k
+          """.stripMargin).queryExecution.executedPlan
+        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        assert(exchanges.size == 2)
+      }
+    }
+  }
+
+  test("aliases to expressions should not be replaced") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("df1", "df2") {
+        spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
+        spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
+        val planned = sql(
+          """
+            |SELECT * FROM
+            |  (SELECT key + 1 AS k1 from df1) t1
+            |INNER JOIN
+            |  (SELECT key + 1 AS k2 from df2) t2
+            |ON t1.k1 = t2.k2
+            |""".stripMargin).queryExecution.executedPlan
+        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+
+        // Make sure aliases to an expression (key + 1) are not replaced.
+        Seq("k1", "k2").foreach { alias =>
+          assert(exchanges.exists(_.outputPartitioning match {
+            case HashPartitioning(Seq(a: AttributeReference), _) => a.name == alias
+            case _ => false
+          }))
+        }
+      }
+    }
+  }
+
+  test("aliases in the aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+      val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+      val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
+      val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")
+
+      val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)
+
+      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+      assert(exchanges.size == 2)
+    }
+  }
+
+  test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(true, false).foreach { useObjectHashAgg =>
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
+          val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+          val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
+
+          val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
+          val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")
+
+          val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+          if (useObjectHashAgg) {
+            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
+          } else {
+            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+          }
+
+          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+          assert(exchanges.size == 2)
+        }
+      }
+    }
+  }
 }

 // Used for unit-testing EnsureRequirements
```

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 14 additions & 0 deletions
```diff
@@ -604,6 +604,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }

+  test("bucket join should work with SubqueryAlias plan") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      withTable("t") {
+        withView("v") {
+          spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
+          sql("CREATE VIEW v AS SELECT * FROM t").collect()
+
+          val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
+          assert(plan.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
+        }
+      }
+    }
+  }
+
   test("avoid shuffle when grouping keys are a super-set of bucket keys") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
```
