Commit f7f8ec1

colinmjj authored and GitHub Enterprise committed
[HADP-55702] Fix incorrect output binding in BroadcastRangeJoinExec (apache#652)
1 parent ad98038 commit f7f8ec1

3 files changed: +180 -4 lines changed

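What the fix does: BroadcastRangeJoinExec normalizes the streamed-side range keys and plan output against streamedPlan.output, i.e. the logical column order. When the streamed side is a partitioned table (here, a partitioned materialized view), the file scan emits the partition columns after the data columns, so the normalized keys end up bound to the wrong attribute positions. The commit detects the partition columns of the (single) underlying relation and rebuilds the streamed output with those columns moved to the end before normalizing.

A minimal plain-Scala sketch of that reordering rule (hypothetical names, not the Spark API): non-partition columns keep their relative order, and partition columns move to the end in partition-spec order.

object StreamOutputOrderSketch {
  // data columns first (original order), then partition columns (spec order)
  def reorder(output: Seq[String], partitionColumns: Seq[String]): Seq[String] =
    output.filterNot(partitionColumns.contains) ++ partitionColumns.filter(output.contains)

  def main(args: Array[String]): Unit = {
    // hypothetical output (d1, d2, col1) partitioned by (d1): d1 moves last
    println(reorder(Seq("d1", "d2", "col1"), Seq("d1")))         // List(d2, col1, d1)
    // partitioned by (d1, col1): (d2, d1, col1), the order testWithRangerJoin2 asserts
    println(reorder(Seq("d1", "d2", "col1"), Seq("d1", "col1"))) // List(d2, d1, col1)
  }
}
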

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 38 additions & 3 deletions

@@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors
 import org.apache.spark.sql.execution.aggregate.AggUtils
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.{WriteFiles, WriteFilesExec}
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, WriteFiles, WriteFilesExec}
 import org.apache.spark.sql.execution.exchange.{REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastRangeJoinExec, RangeInfo}
 import org.apache.spark.sql.execution.python._

@@ -414,6 +414,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
       }

+      // only process the plan with only one LogicalRelation
+      def getPartitionColumns(plan: LogicalPlan): Seq[String] = {
+        var partitionColumns = Seq.empty[String]
+        var findRelation = false
+        plan foreach {
+          case LogicalRelation(_, _, catalogTable, _) if catalogTable.isDefined =>
+            if (findRelation) {
+              return Seq.empty
+            }
+            findRelation = true
+            partitionColumns = catalogTable.get.partitionColumnNames
+          case _ =>
+        }
+        partitionColumns
+      }
+
       def createBroadcastRangeJoinExec(leftRangeKeys: Seq[Expression],
           rightRangeKeys: Seq[Expression],
           equality: Seq[Boolean],

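The helper above gives up unless the plan contains exactly one catalog-backed LogicalRelation, since with several relations there is no single partition spec to order the output by. A plain-Scala sketch of that guard, using a hypothetical stand-in for the relation node:

object SingleRelationGuardSketch {
  // stand-in for a catalog-backed relation carrying its partition columns
  final case class Relation(partitionColumnNames: Seq[String])

  def getPartitionColumns(relations: Seq[Relation]): Seq[String] = relations match {
    case Seq(only) => only.partitionColumnNames // exactly one relation: trust its spec
    case _ => Seq.empty                         // zero or several: no usable spec
  }

  def main(args: Array[String]): Unit = {
    println(getPartitionColumns(Seq(Relation(Seq("d1")))))                // List(d1)
    println(getPartitionColumns(Seq(Relation(Seq("d1")), Relation(Nil)))) // List()
  }
}
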
@@ -435,10 +451,29 @@
         (buildKeys.flatMap(e => QueryPlan.normalizePredicates(e :: Nil, buildPlan.output)),
           buildPlan.output.map(QueryPlan.normalizeExpressions(_, buildPlan.output)))

+        // for partition table, the partition columns should be put at the end of the output
+        val streamPartitionColumns = getPartitionColumns(streamedPlan)
+        val streamOutput = if (streamPartitionColumns.isEmpty) {
+          streamedPlan.output
+        } else {
+          var orderedPartitionOutput = Seq.empty[Attribute]
+          // order partition columns
+          streamPartitionColumns.foreach(col => {
+            streamedPlan.output.foreach(att => {
+              if (att.name == col) {
+                orderedPartitionOutput = orderedPartitionOutput :+ att
+              }
+            })
+          })
+          // merge all output with order
+          streamedPlan.output.filter(
+            att => !streamPartitionColumns.contains(att.name)) ++ orderedPartitionOutput
+        }
+
         val (normalizedStreamedKeys, normalizedStreamedPlanOutput) =
           (streamedKeys.flatMap(e =>
-            QueryPlan.normalizePredicates(e :: Nil, streamedPlan.output)),
-            streamedPlan.output.map(QueryPlan.normalizeExpressions(_, streamedPlan.output)))
+            QueryPlan.normalizePredicates(e :: Nil, streamOutput)),
+            streamOutput.map(QueryPlan.normalizeExpressions(_, streamOutput)))

         val allOutput = left.output ++ right.output

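Why the sequence passed to normalization matters: attributes are rebound by their position in the supplied output, so normalizing the streamed keys against the logical column order while the scan actually emits partition columns last leaves the range key pointing at the wrong slot. A toy sketch (hypothetical bind helper, not Spark's QueryPlan API):

object PositionalBindingSketch {
  // bind an attribute name to its ordinal in the given output
  def bind(key: String, output: Seq[String]): Int = output.indexOf(key)

  def main(args: Array[String]): Unit = {
    val logicalOrder = Seq("d1", "d2", "col1") // order used before this fix
    val scanOrder = Seq("d2", "col1", "d1")    // partitioned scan: d1 emitted last
    println(bind("d1", logicalOrder)) // 0, a stale slot at execution time
    println(bind("d1", scanOrder))    // 2, matching what the scan emits
  }
}
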
sql/core/src/test/scala/org/apache/spark/sql/execution/mv/MaterializedViewOptimizerBaseSuite.scala

Lines changed: 65 additions & 1 deletion

@@ -21,6 +21,7 @@ import java.io.File

 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils

@@ -179,6 +180,26 @@ class MaterializedViewOptimizerBaseSuite extends SparkFunSuite with SharedSparkS
         |) using parquet
         |""".stripMargin)

+    sql(
+      """
+        |create table db2.range_t1 (
+        |col1 int,
+        |col2 string,
+        |d1 date,
+        |d2 date
+        |) using parquet
+        |""".stripMargin)
+
+    sql(
+      """
+        |create table db2.range_t2 (
+        |col1 int,
+        |col2 string,
+        |d1 date,
+        |d2 date
+        |) using parquet
+        |""".stripMargin)
+
     mvDbPath = Utils.createTempDir(dir.getAbsolutePath, "mv_db")
     sql(s"create database mv_db location '${mvDbPath.getAbsolutePath}'")
   }

@@ -194,6 +215,8 @@ class MaterializedViewOptimizerBaseSuite extends SparkFunSuite with SharedSparkS
     sql("drop table db2.company")
     sql("drop table db2.dependents")
     sql("drop table db2.locations")
+    sql("drop table db2.range_t1")
+    sql("drop table db2.range_t2")
     sql("drop database db1")
     sql("drop database db2")
     sql("drop database mv_db")

@@ -254,12 +277,19 @@ class MaterializedViewOptimizerBaseSuite extends SparkFunSuite with SharedSparkS
       mvName: String,
       mvQuery: String,
       query: String,
-      expectedResult: String): Unit = {
+      expectedResult: String,
+      partitionColumns: Seq[String] = Seq.empty): Unit = {
     val mvTablePath = new File(mvDbPath, mvName)
+    val partitionDesc = if (!partitionColumns.isEmpty) {
+      "partitioned by (" + partitionColumns.mkString(", ") + ")"
+    } else {
+      ""
+    }
     try {
       sql(
         s"""
           |create materialized view mv_db.$mvName using parquet
+          |$partitionDesc
           |as $mvQuery
           |""".stripMargin)

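For reference, the partitionDesc fragment interpolated into the generated CREATE statement expands as follows; this is a runnable, standalone copy of the same expression (plain Scala):

object PartitionDescSketch {
  def partitionDesc(partitionColumns: Seq[String]): String =
    if (!partitionColumns.isEmpty) {
      "partitioned by (" + partitionColumns.mkString(", ") + ")"
    } else {
      ""
    }

  def main(args: Array[String]): Unit = {
    println(partitionDesc(Seq("d1")))         // partitioned by (d1)
    println(partitionDesc(Seq("d1", "col1"))) // partitioned by (d1, col1)
    println(partitionDesc(Seq.empty))         // empty string: unpartitioned view
  }
}
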
@@ -281,6 +311,40 @@ class MaterializedViewOptimizerBaseSuite extends SparkFunSuite with SharedSparkS
     }
   }

+  protected def checkSparkPlanWithMaterializedView(
+      mvName: String,
+      mvQuery: String,
+      query: String,
+      partitionColumns: Seq[String] = Seq.empty)(f: SparkPlan => Unit): Unit = {
+    val mvTablePath = new File(mvDbPath, mvName)
+    val partitionDesc = if (!partitionColumns.isEmpty) {
+      "partitioned by (" + partitionColumns.mkString(", ") + ")"
+    } else {
+      ""
+    }
+    try {
+      sql(
+        s"""
+          |create materialized view mv_db.$mvName using parquet
+          |$partitionDesc
+          |as $mvQuery
+          |""".stripMargin)
+
+      MaterializedViewManager.cacheMaterializedView()
+
+      val dfResult = sql(query)
+      assert(dfResult.queryExecution.materialized.
+        getOptimizeTags().contains("MATERIALIZED_VIEW_OPTIMIZED"))
+      f(dfResult.queryExecution.sparkPlan)
+    } catch {
+      case ex: Exception =>
+        fail(ex)
+    } finally {
+      dropMaterializedView(mvName, mvTablePath)
+      MaterializedViewManager.clearCache()
+    }
+  }
+
   protected def withUnSatisfiedMaterializedView(
       mvName: String,
       mvQuery: String,

sql/core/src/test/scala/org/apache/spark/sql/execution/mv/MaterializedViewOptimizerSuite.scala

Lines changed: 77 additions & 0 deletions

@@ -18,6 +18,10 @@
 package org.apache.spark.sql.execution.mv

 import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.execution.joins.BroadcastRangeJoinExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType

 class MaterializedViewOptimizerSuite extends MaterializedViewOptimizerBaseSuite {

@@ -3256,4 +3260,77 @@
       mvOptimizedCount("testmv") == 2)
     }
   }
+
+  test("testWithRangerJoin1") {
+    withSQLConf(SQLConf.RANGE_JOIN_ENABLED.key -> "true") {
+      checkSparkPlanWithMaterializedView("testmv",
+        mvQuery =
+          """
+            |select * from db2.range_t1
+            |""".stripMargin,
+        partitionColumns = Seq("d1"),
+        query =
+          """
+            |select t1.d1, t1.d2, t1.col1, t2.col2
+            |from
+            |  db2.range_t1 t1
+            |join
+            |  db2.range_t2 t2
+            |  on t1.d1 < t2.d1
+            |  and t1.d1 >= t2.d1 - interval '30' day
+            |""".stripMargin
+      ) {
+        p => {
+          assert(p.exists(_.isInstanceOf[BroadcastRangeJoinExec]))
+          p.foreach {
+            case b: BroadcastRangeJoinExec =>
+              val rangeInfo = b.rangeInfo
+              val keys = rangeInfo.normalizedStreamedKeys
+              val output = rangeInfo.normalizedStreamedPlanOutput
+              assert(output.last.exprId == keys.head.asInstanceOf[AttributeReference].exprId)
+            case _ =>
+          }
+        }
+      }
+    }
+  }
+
+  test("testWithRangerJoin2") {
+    withSQLConf(SQLConf.RANGE_JOIN_ENABLED.key -> "true") {
+      checkSparkPlanWithMaterializedView("testmv",
+        mvQuery =
+          """
+            |select * from db2.range_t1
+            |""".stripMargin,
+        partitionColumns = Seq("d1", "col1"),
+        query =
+          """
+            |select t1.d1, t1.d2, t1.col1, t2.col2
+            |from
+            |  db2.range_t1 t1
+            |join
+            |  db2.range_t2 t2
+            |  on t1.d1 < t2.d1
+            |  and t1.d1 >= t2.d1 - interval '30' day
+            |""".stripMargin
+      ) {
+        p => {
+          assert(p.exists(_.isInstanceOf[BroadcastRangeJoinExec]))
+          p.foreach {
+            case b: BroadcastRangeJoinExec =>
+              val rangeInfo = b.rangeInfo
+              val keys = rangeInfo.normalizedStreamedKeys
+              val output = rangeInfo.normalizedStreamedPlanOutput
+              // should be d2, d1, col1
+              assert(output.size == 3)
+              // should be d1
+              assert(output(1).exprId == keys.head.asInstanceOf[AttributeReference].exprId)
+              // should be col1
+              assert(output.last.dataType == IntegerType)
+            case _ =>
+          }
+        }
+      }
+    }
+  }
 }
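
Reading the new tests against the SparkStrategies change: testWithRangerJoin1 partitions the view by (d1) and asserts only that the rebound streamed output ends with the range key d1 (by exprId). testWithRangerJoin2 partitions by (d1, col1); under the reordering rule the streamed output becomes (d2, d1, col1), so the assertions pin the size to 3, the range key d1 to index 1, and an IntegerType col1 in last position.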
