
Commit 04f7f6d

wzhfy authored and cloud-fan committed
[SPARK-32748][SQL] Support local property propagation in SubqueryBroadcastExec
### What changes were proposed in this pull request?

Since [SPARK-22590](2854091), local property propagation is supported through `SQLExecution.withThreadLocalCaptured` in both `BroadcastExchangeExec` and `SubqueryExec` when computing `relationFuture`. This PR adds the same support to `SubqueryBroadcastExec`.

### Why are the changes needed?

Local property propagation is missing in `SubqueryBroadcastExec`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a new test.

Closes #29589 from wzhfy/thread_local.

Authored-by: Zhenhua Wang <wzh_zju@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
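For readers unfamiliar with the mechanism: local properties are per-thread key-value pairs on `SparkContext` that are attached to jobs submitted from that thread and exposed to tasks through `TaskContext`. A minimal sketch of that behavior, not part of the patch (assumes a local-mode session; `my.prop` is an arbitrary illustrative key):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object LocalPropertySketch extends App {
  // Hypothetical local-mode session, for illustration only.
  val spark = SparkSession.builder().master("local[2]").appName("local-props").getOrCreate()
  val sc = spark.sparkContext

  // Local properties are per-thread: jobs submitted from this thread carry
  // the property, and their tasks can read it through TaskContext.
  sc.setLocalProperty("my.prop", "my-value")
  val seen = sc.parallelize(1 to 2, 2)
    .map(_ => TaskContext.get().getLocalProperty("my.prop"))
    .collect()
  assert(seen.forall(_ == "my-value"))

  // A pooled thread created earlier (or reused across queries) does not see
  // this thread's current properties, hence the explicit capture in this PR.
  spark.stop()
}
```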
1 parent b0322bf · commit 04f7f6d

2 files changed: +72 -7 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala

Lines changed: 11 additions & 5 deletions
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import scala.concurrent.{ExecutionContext, Future}
+import java.util.concurrent.{Future => JFuture}
+
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.rdd.RDD
@@ -26,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -60,10 +63,12 @@ case class SubqueryBroadcastExec(
   }
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      sqlContext.sparkSession,
+      SubqueryBroadcastExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
@@ -89,7 +94,7 @@
 
         rows
       }
-    }(SubqueryBroadcastExec.executionContext)
+    }
   }
 
   protected override def doPrepare(): Unit = {
@@ -110,5 +115,6 @@
 
 object SubqueryBroadcastExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16))
+    ThreadUtils.newDaemonCachedThreadPool("dynamic-pruning",
+      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
 }
sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala

Lines changed: 61 additions & 2 deletions
@@ -25,9 +25,9 @@ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, InSubqueryExec, LeafExecNode, QueryExecution, SparkPlan, SubqueryBroadcastExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.execution.debug.codegenStringSeq
 import org.apache.spark.sql.functions.col
@@ -188,6 +188,65 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
       assert(checks2.forall(_.toSeq == Seq(true, true)))
     }
   }
+
+  test("SPARK-32748: propagate local properties to dynamic pruning thread") {
+    val factTable = "fact_local_prop_dpp"
+    val dimTable = "dim_local_prop_dpp"
+
+    def checkPropertyValueByUdfResult(propKey: String, propValue: String): Unit = {
+      spark.sparkContext.setLocalProperty(propKey, propValue)
+      val df = sql(
+        s"""
+           |SELECT compare_property_value(f.id, '$propKey', '$propValue') as col
+           |FROM $factTable f
+           |INNER JOIN $dimTable s
+           |ON f.id = s.id AND s.value < 3
+         """.stripMargin)
+
+      val subqueryBroadcastSeq = df.queryExecution.executedPlan.flatMap {
+        case s: FileSourceScanExec => s.partitionFilters.collect {
+          case DynamicPruningExpression(InSubqueryExec(_, b: SubqueryBroadcastExec, _, _)) => b
+        }
+        case _ => Nil
+      }
+      assert(subqueryBroadcastSeq.nonEmpty,
+        s"Should trigger DPP with a reused broadcast exchange:\n${df.queryExecution}")
+
+      assert(df.collect().forall(_.toSeq == Seq(true)))
+    }
+
+    withTable(factTable, dimTable) {
+      spark.range(10).select($"id", $"id".as("value"))
+        .write.partitionBy("id").mode("overwrite").saveAsTable(factTable)
+      spark.range(5).select($"id", $"id".as("value"))
+        .write.mode("overwrite").saveAsTable(dimTable)
+
+      withSQLConf(
+        StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1",
+        SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+        SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+
+        try {
+          spark.udf.register(
+            "compare_property_value",
+            (input: Int, propKey: String, propValue: String) =>
+              TaskContext.get().getLocalProperty(propKey) == propValue
+          )
+          val propKey = "spark.sql.subquery.broadcast.prop.key"
+
+          // set local property and assert
+          val propValue1 = UUID.randomUUID().toString()
+          checkPropertyValueByUdfResult(propKey, propValue1)
+
+          // change local property and re-assert
+          val propValue2 = UUID.randomUUID().toString()
+          checkPropertyValueByUdfResult(propKey, propValue2)
+        } finally {
+          spark.sessionState.catalog.dropTempFunction("compare_property_value", true)
+        }
+      }
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
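A note on the test setup: pinning `BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD` to 1 forces both queries' dynamic-pruning subqueries onto the same pooled thread, so a stale property from the first run would surface in the second if propagation were broken. A minimal illustration of why thread reuse matters (illustrative only, not from the patch):

```scala
import java.util.concurrent.Executors

object ThreadReuseSketch extends App {
  val pool = Executors.newFixedThreadPool(1) // single thread => guaranteed reuse
  val prop = new ThreadLocal[String]

  // First task sets a thread-local on the pooled thread.
  pool.submit(new Runnable { def run(): Unit = prop.set("value1") }).get()

  // Second task reuses the same thread and still sees the stale value,
  // unless the submitter re-captures and re-installs it per submission.
  pool.submit(new Runnable {
    def run(): Unit = println(s"second task sees: ${prop.get()}") // prints value1
  }).get()

  pool.shutdown()
}
```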
