
Commit 2854091

ajithme authored and cloud-fan committed
[SPARK-22590][SQL] Copy sparkContext.localproperties to child thread in BroadcastExchangeExec.executionContext
### What changes were proposed in this pull request?

In `org.apache.spark.sql.execution.exchange.BroadcastExchangeExec#relationFuture`, make a copy of `org.apache.spark.SparkContext#localProperties` and pass it to the broadcast execution thread in `org.apache.spark.sql.execution.exchange.BroadcastExchangeExec#executionContext`.

### Why are the changes needed?

When `BroadcastExchangeExec` is executed, `relationFuture` is evaluated on a separate thread. Those threads inherit `localProperties` from `sparkContext` because they are its child threads, and they are created in `executionContext` (a thread pool). Each thread pool keeps idle threads alive for a default `keepAliveSeconds` of 60 seconds. When an idle pooled thread is reused for a subsequent query, it does not re-inherit the thread-local properties from the Spark context (thread properties are inherited only at thread creation), so it ends up with stale or missing properties. This causes task-set properties to be missing when the child thread submits work via `sparkContext.runJob`/`submitJob`.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added a unit test.

Closes #27266 from ajithme/broadcastlocalprop.

Authored-by: Ajith <ajith2489@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent afaeb29 commit 2854091
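
For context, here is a minimal, self-contained Scala sketch (not part of the patch; all names are illustrative) of the pitfall described above: `InheritableThreadLocal` values, which back `SparkContext#localProperties`, are copied from the parent only when a thread is created, so an idle pooled thread reused for a later query keeps the old values.

```scala
import java.util.concurrent.Executors

object StaleLocalPropsDemo {
  // Stand-in for SparkContext#localProperties (an InheritableThreadLocal under the hood).
  private val localProp = new InheritableThreadLocal[String] {
    override def initialValue(): String = "unset"
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1) // one worker, like a broadcast exchange thread

    localProp.set("query-1")
    // The first submit lazily creates the worker thread, which inherits "query-1".
    pool.submit(new Runnable { def run(): Unit = println(localProp.get()) }).get() // query-1

    localProp.set("query-2")
    // The idle worker is reused, so it never re-inherits the new value.
    pool.submit(new Runnable { def run(): Unit = println(localProp.get()) }).get() // still query-1

    pool.shutdown()
  }
}
```

The patch sidesteps this by snapshotting the caller's properties with `Utils.cloneProperties` and installing them inside the submitted task itself.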

File tree

5 files changed: +56 −20 lines


core/src/main/scala/org/apache/spark/util/ThreadUtils.scala

Lines changed: 17 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
@@ -304,6 +305,22 @@ private[spark] object ThreadUtils {
   }
   // scalastyle:on awaitresult
 
+  @throws(classOf[SparkException])
+  def awaitResult[T](future: JFuture[T], atMost: Duration): T = {
+    try {
+      atMost match {
+        case Duration.Inf => future.get()
+        case _ => future.get(atMost._1, atMost._2)
+      }
+    } catch {
+      case e: SparkFatalException =>
+        throw e.throwable
+      case NonFatal(t)
+          if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
+        throw new SparkException("Exception thrown in awaitResult: ", t)
+    }
+  }
+
   // scalastyle:off awaitready
   /**
    * Preferred alternative to `Await.ready()`.
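
A quick usage sketch of the new overload (illustrative only; `ThreadUtils` is `private[spark]`, so real callers live inside Spark):

```scala
import java.util.concurrent.{Callable, Executors, Future => JFuture}
import scala.concurrent.duration._
import org.apache.spark.util.ThreadUtils

val pool = Executors.newSingleThreadExecutor()
// ExecutorService.submit returns a java.util.concurrent.Future, which the existing
// scala.concurrent-based awaitResult cannot consume; this overload fills that gap.
val jf: JFuture[Int] = pool.submit(new Callable[Int] { def call(): Int = 42 })
val result = ThreadUtils.awaitResult(jf, 10.seconds) // Duration.Inf blocks indefinitely
pool.shutdown()
```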

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 4 additions & 6 deletions
@@ -17,11 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
 import java.util.concurrent.atomic.AtomicLong
 
-import scala.concurrent.{ExecutionContext, Future}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
@@ -172,11 +170,11 @@ object SQLExecution {
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
-      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
+      sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
     val activeSession = sparkSession
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
-    Future {
+    exec.submit(() => {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
@@ -190,6 +188,6 @@
         SparkSession.clearActiveSession()
       }
       res
-    }(exec)
+    })
   }
 }
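
The return type switches from `scala.concurrent.Future` to `java.util.concurrent.Future` because `ExecutorService.submit` hands one back directly; callers then block on it with the `ThreadUtils.awaitResult` overload added above. A sketch of the resulting call pattern (mirroring what `SubqueryExec` does below; the surrounding names are assumptions drawn from the adjacent diffs):

```scala
// Runs the body on the pool thread with the caller's local properties captured, then awaits it.
val future: JFuture[Array[InternalRow]] =
  SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
    sqlContext.sparkSession, SubqueryExec.executionContext) {
    child.executeCollect() // sees the captured SparkContext local properties
  }
val rows: Array[InternalRow] = ThreadUtils.awaitResult(future, Duration.Inf)
```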

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 3 additions & 2 deletions
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._
 
 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext}
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -746,7 +747,7 @@ case class SubqueryExec(name: String, child: SparkPlan)
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLExecution.withThreadLocalCaptured[Array[InternalRow]](

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 4 additions & 12 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.util.{SparkFatalException, ThreadUtils}
+import org.apache.spark.util.{SparkFatalException, ThreadUtils, Utils}
 
 /**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
@@ -73,13 +73,8 @@ case class BroadcastExchangeExec(
 
   @transient
   private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    val task = new Callable[broadcast.Broadcast[Any]]() {
-      override def call(): broadcast.Broadcast[Any] = {
-        // This will run in another thread. Set the execution id so that we can connect these jobs
-        // with the correct execution.
-        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
           try {
             // Setup a job group here so later it may get cancelled by groupId if necessary.
             sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
@@ -121,7 +116,7 @@ case class BroadcastExchangeExec(
             val broadcasted = sparkContext.broadcast(relation)
             longMetric("broadcastTime") += NANOSECONDS.toMillis(
               System.nanoTime() - beforeBroadcast)
-
+            val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
             SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
             promise.success(broadcasted)
             broadcasted
@@ -146,10 +141,7 @@ case class BroadcastExchangeExec(
             promise.failure(e)
             throw e
           }
-        }
-      }
     }
-    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
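
With the properties now reaching the broadcast thread, anything keyed off them (job groups, the SQL execution ID, arbitrary `setLocalProperty` values) flows through `sparkContext.runJob` to the tasks. A minimal sketch of that end-to-end visibility, which the new test below asserts through a broadcast join (assumes an active `spark` session; the key and value are illustrative):

```scala
import org.apache.spark.TaskContext

spark.sparkContext.setLocalProperty("spark.sql.y", "some-value") // set on the driver thread
spark.range(1).rdd.foreachPartition { _ =>
  // Tasks see the submitting thread's local properties via TaskContext.
  assert(TaskContext.get.getLocalProperty("spark.sql.y") == "some-value")
}
```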

sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala

Lines changed: 28 additions & 0 deletions
@@ -159,6 +159,34 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-22590 propagate local properties to broadcast execution thread") {
+    withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1") {
+      val df1 = Seq(true).toDF()
+      val confKey = "spark.sql.y"
+      val confValue1 = UUID.randomUUID().toString()
+      val confValue2 = UUID.randomUUID().toString()
+
+      def generateBroadcastDataFrame(confKey: String, confValue: String): Dataset[Boolean] = {
+        val df = spark.range(1).mapPartitions { _ =>
+          Iterator(TaskContext.get.getLocalProperty(confKey) == confValue)
+        }
+        df.hint("broadcast")
+      }
+
+      // set local property and assert
+      val df2 = generateBroadcastDataFrame(confKey, confValue1)
+      spark.sparkContext.setLocalProperty(confKey, confValue1)
+      val checks = df1.join(df2).collect()
+      assert(checks.forall(_.toSeq == Seq(true, true)))
+
+      // change local property and re-assert
+      val df3 = generateBroadcastDataFrame(confKey, confValue2)
+      spark.sparkContext.setLocalProperty(confKey, confValue2)
+      val checks2 = df1.join(df3).collect()
+      assert(checks2.forall(_.toSeq == Seq(true, true)))
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
