
Commit c4bbfd1

gaborgsomogyi authored and Marcelo Vanzin committed
[SPARK-24063][SS] Add maximum epoch queue threshold for ContinuousExecution
## What changes were proposed in this pull request?

Continuous processing waits on epochs that are not yet complete (for example, when one partition is not making progress) and stores the pending items in queues. These queues are unbounded and can easily consume all available memory. This PR adds the `spark.sql.streaming.continuous.epochBacklogQueueSize` configuration to make them bounded. If the size of a queue exceeds this threshold, the query stops with an `IllegalStateException`.

## How was this patch tested?

Existing + additional unit tests.

Closes #23156 from gaborgsomogyi/SPARK-24063.

Authored-by: Gabor Somogyi <gabor.g.somogyi@gmail.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
1 parent a6ddc9d commit c4bbfd1
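
For context, here is a minimal sketch (not part of this commit; values, the `rate` source, and the `console` sink are illustrative) of how a user would set the new option on a continuous query:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object EpochBacklogQueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("epoch-backlog-example")
      // Bound the epoch coordinator's pending queues (10000 is the default added by this commit).
      .config("spark.sql.streaming.continuous.epochBacklogQueueSize", "1000")
      .getOrCreate()

    val query = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .writeStream
      .format("console")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    // If a partition stalls long enough for a backlog queue to exceed the threshold,
    // the query stops and the IllegalStateException surfaces here (wrapped in a
    // StreamingQueryException).
    query.awaitTermination()
  }
}
```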

File tree

5 files changed: +177 additions, −3 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -1436,6 +1436,13 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
+    buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
+      .doc("The max number of entries to be stored in queue to wait for late epochs. " +
+        "If this parameter is exceeded by the size of the queue, stream will stop with an error.")
+      .intConf
+      .createWithDefault(10000)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
       .internal()
@@ -2066,6 +2073,9 @@ class SQLConf extends Serializable with Logging {
 
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
 
+  def continuousStreamingEpochBacklogQueueSize: Int =
+    getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)
+
   def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)
 
   def continuousStreamingExecutorPollIntervalMs: Long =
```
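
As a quick illustration (not part of the diff; key and default are taken from the entry above), the new setting should be settable like any other runtime SQL conf through the public `spark.conf` API, while the typed getter above is what Spark uses internally via `SQLConf`:

```scala
import org.apache.spark.sql.SparkSession

object EpochBacklogConfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("epoch-backlog-conf")
      .getOrCreate()

    // Override the default of 10000 for this session.
    spark.conf.set("spark.sql.streaming.continuous.epochBacklogQueueSize", "100")

    // spark.conf.get returns the string form of the configured value.
    val limit = spark.conf.get("spark.sql.streaming.continuous.epochBacklogQueueSize").toInt
    assert(limit == 100)

    spark.stop()
  }
}
```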

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala

Lines changed: 38 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous
 
 import java.util.UUID
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 import java.util.function.UnaryOperator
 
 import scala.collection.JavaConverters._
@@ -58,6 +59,9 @@ class ContinuousExecution(
   // For use only in test harnesses.
   private[sql] var currentEpochCoordinatorId: String = _
 
+  // Throwable that caused the execution to fail
+  private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null)
+
   override val logicalPlan: LogicalPlan = {
     val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]()
     var nextSourceId = 0
@@ -261,6 +265,11 @@
           lastExecution.toRdd
         }
       }
+
+      val f = failure.get()
+      if (f != null) {
+        throw f
+      }
     } catch {
       case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
           state.get() == RECONFIGURING =>
@@ -373,6 +382,35 @@
     }
   }
 
+  /**
+   * Stores error and stops the query execution thread to terminate the query in new thread.
+   */
+  def stopInNewThread(error: Throwable): Unit = {
+    if (failure.compareAndSet(null, error)) {
+      logError(s"Query $prettyIdString received exception $error")
+      stopInNewThread()
+    }
+  }
+
+  /**
+   * Stops the query execution thread to terminate the query in new thread.
+   */
+  private def stopInNewThread(): Unit = {
+    new Thread("stop-continuous-execution") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          ContinuousExecution.this.stop()
+        } catch {
+          case e: Throwable =>
+            logError(e.getMessage, e)
+            throw e
+        }
+      }
+    }.start()
+  }
+
   /**
    * Stops the query execution thread to terminate the query.
    */
```
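
To make the control flow above easier to follow, here is a self-contained sketch (plain Scala, not Spark code; class and method names are illustrative) of the same pattern: an `AtomicReference` latches only the first failure, the stop runs on a separate daemon thread so the reporting thread is never blocked, and the run loop rethrows the recorded error:

```scala
import java.util.concurrent.atomic.AtomicReference

class StoppableQuery {
  // Holds the first failure reported; later failures are ignored.
  private val failure = new AtomicReference[Throwable](null)

  def stop(): Unit = println("query stopped")

  def stopInNewThread(error: Throwable): Unit = {
    // compareAndSet ensures only the first reported error wins.
    if (failure.compareAndSet(null, error)) {
      new Thread("stop-query") {
        setDaemon(true)
        override def run(): Unit = StoppableQuery.this.stop()
      }.start()
    }
  }

  // The query's own run loop rethrows the recorded failure, if any.
  def rethrowIfFailed(): Unit = {
    val f = failure.get()
    if (f != null) throw f
  }
}

object StoppableQueryDemo extends App {
  val q = new StoppableQuery
  q.stopInNewThread(new IllegalStateException("backlog exceeded"))
  q.stopInNewThread(new IllegalStateException("ignored: a failure was already recorded"))
  Thread.sleep(100) // give the daemon stop thread time to run before the JVM exits
  try q.rethrowIfFailed() catch {
    case e: IllegalStateException => println(s"run loop would fail with: ${e.getMessage}")
  }
}
```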

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala

Lines changed: 20 additions & 0 deletions
```diff
@@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
 
+  private val epochBacklogQueueSize =
+    session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize
+
   private var queryWritesStopped: Boolean = false
 
   private var numReaderPartitions: Int = _
@@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
       if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
         partitionCommits.put((epoch, partitionId), message)
         resolveCommitsAtEpoch(epoch)
+        checkProcessingQueueBoundaries()
       }
 
     case ReportPartitionOffset(partitionId, epoch, offset) =>
@@ -223,6 +227,22 @@
         query.addOffset(epoch, stream, thisEpochOffsets.toSeq)
         resolveCommitsAtEpoch(epoch)
       }
+      checkProcessingQueueBoundaries()
+  }
+
+  private def checkProcessingQueueBoundaries() = {
+    if (partitionOffsets.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
+        "exceeded its maximum"))
+    }
+    if (partitionCommits.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
+        "exceeded its maximum"))
+    }
+    if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
+      query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
+        "exceeded its maximum"))
+    }
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
```
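
The boundary check itself is simple size bookkeeping. Below is a self-contained sketch (illustrative names, not Spark code) of the idea: pending offsets and commits are keyed by `(epoch, partition)`, and as soon as any map outgrows the configured limit a failure callback, standing in for `query.stopInNewThread`, is invoked:

```scala
import scala.collection.mutable

class BacklogTracker(epochBacklogQueueSize: Int, failQuery: Throwable => Unit) {
  // Pending per-epoch, per-partition state, mirroring the coordinator's maps.
  private val partitionOffsets = mutable.Map[(Long, Int), String]()
  private val partitionCommits = mutable.Map[(Long, Int), String]()

  def recordOffset(epoch: Long, partition: Int, offset: String): Unit = {
    partitionOffsets.put((epoch, partition), offset)
    checkBoundaries()
  }

  def recordCommit(epoch: Long, partition: Int, message: String): Unit = {
    partitionCommits.put((epoch, partition), message)
    checkBoundaries()
  }

  private def checkBoundaries(): Unit = {
    if (partitionOffsets.size > epochBacklogQueueSize) {
      failQuery(new IllegalStateException(
        "Size of the partition offset queue has exceeded its maximum"))
    }
    if (partitionCommits.size > epochBacklogQueueSize) {
      failQuery(new IllegalStateException(
        "Size of the partition commit queue has exceeded its maximum"))
    }
  }
}

object BacklogTrackerDemo extends App {
  val tracker = new BacklogTracker(epochBacklogQueueSize = 3,
    e => println(s"stopping: ${e.getMessage}"))
  // One partition keeps reporting offsets while commits never arrive,
  // so the offset map eventually crosses the threshold.
  (1L to 5L).foreach(epoch => tracker.recordOffset(epoch, partition = 0, offset = s"offset-$epoch"))
}
```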

sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala

Lines changed: 31 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
 import org.apache.spark.sql.streaming.{StreamTest, Trigger}
 import org.apache.spark.sql.test.TestSparkSession
 
@@ -343,3 +344,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
     }
   }
 }
+
+class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
+  import testImplicits._
+
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[1]",
+      "continuous-stream-test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
+  // This test forces the backlog to overflow by not standing up enough executors for the query
+  // to make progress.
+  test("epoch backlog overflow") {
+    withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) {
+      val df = spark.readStream
+        .format("rate")
+        .option("numPartitions", "2")
+        .option("rowsPerSecond", "500")
+        .load()
+        .select('value)
+
+      testStream(df, useV2Sink = true)(
+        StartStream(Trigger.Continuous(1)),
+        ExpectFailure[IllegalStateException] { e =>
+          e.getMessage.contains("queue has exceeded its maximum")
+        }
+      )
+    }
+  }
+}
```

sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala

Lines changed: 78 additions & 3 deletions
```diff
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql.streaming.continuous
 
+import org.mockito.{ArgumentCaptor, InOrder}
 import org.mockito.ArgumentMatchers.{any, eq => eqTo}
-import org.mockito.InOrder
-import org.mockito.Mockito.{inOrder, never, verify}
+import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.mockito.MockitoSugar
 
 import org.apache.spark._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.sql.LocalSparkSession
 import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
 import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
@@ -43,14 +44,19 @@ class EpochCoordinatorSuite
   private var writeSupport: StreamingWrite = _
   private var query: ContinuousExecution = _
   private var orderVerifier: InOrder = _
+  private val epochBacklogQueueSize = 10
 
   override def beforeEach(): Unit = {
     val stream = mock[ContinuousStream]
     writeSupport = mock[StreamingWrite]
     query = mock[ContinuousExecution]
     orderVerifier = inOrder(writeSupport, query)
 
-    spark = new TestSparkSession()
+    spark = new TestSparkSession(
+      new SparkContext(
+        "local[2]", "test-sql-context",
+        new SparkConf().set("spark.sql.testkey", "true")
+          .set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize)))
 
     epochCoordinator
       = EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get)
@@ -186,6 +192,66 @@
     verifyCommitsInOrderOf(List(1, 2, 3, 4, 5))
   }
 
+  test("several epochs, max epoch backlog reached by partitionOffsets") {
+    setWriterPartitions(1)
+    setReaderPartitions(1)
+
+    reportPartitionOffset(0, 1)
+    // Commit messages not arriving
+    for (i <- 2 to epochBacklogQueueSize + 1) {
+      reportPartitionOffset(0, i)
+    }
+
+    makeSynchronousCall()
+
+    for (i <- 1 to epochBacklogQueueSize + 1) {
+      verifyNoCommitFor(i)
+    }
+    verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum")
+  }
+
+  test("several epochs, max epoch backlog reached by partitionCommits") {
+    setWriterPartitions(1)
+    setReaderPartitions(1)
+
+    commitPartitionEpoch(0, 1)
+    // Offset messages not arriving
+    for (i <- 2 to epochBacklogQueueSize + 1) {
+      commitPartitionEpoch(0, i)
+    }
+
+    makeSynchronousCall()
+
+    for (i <- 1 to epochBacklogQueueSize + 1) {
+      verifyNoCommitFor(i)
+    }
+    verifyStoppedWithException("Size of the partition commit queue has exceeded its maximum")
+  }
+
+  test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") {
+    setWriterPartitions(2)
+    setReaderPartitions(2)
+
+    commitPartitionEpoch(0, 1)
+    reportPartitionOffset(0, 1)
+
+    // For partition 2 epoch 1 messages never arriving
+    // +2 because the first epoch not yet arrived
+    for (i <- 2 to epochBacklogQueueSize + 2) {
+      commitPartitionEpoch(0, i)
+      reportPartitionOffset(0, i)
+      commitPartitionEpoch(1, i)
+      reportPartitionOffset(1, i)
+    }
+
+    makeSynchronousCall()
+
+    for (i <- 1 to epochBacklogQueueSize + 2) {
+      verifyNoCommitFor(i)
+    }
+    verifyStoppedWithException("Size of the epoch queue has exceeded its maximum")
+  }
+
   private def setWriterPartitions(numPartitions: Int): Unit = {
     epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions))
   }
@@ -221,4 +287,13 @@
   private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = {
     epochs.foreach(verifyCommit)
   }
+
+  private def verifyStoppedWithException(msg: String): Unit = {
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]);
+    verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture())
+
+    import scala.collection.JavaConverters._
+    val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg)
+    assert(throwable != null, "Stream stopped with an exception but expected message is missing")
+  }
 }
```
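
The `verifyStoppedWithException` helper relies on Mockito's `ArgumentCaptor` to inspect every `Throwable` handed to the mocked query. Here is a stripped-down sketch of that verification pattern (the `FailingQuery` trait and the direct call on the mock are hypothetical stand-ins, not the real suite):

```scala
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.{atLeastOnce, mock, verify}

// Hypothetical collaborator, standing in for ContinuousExecution in the real suite.
trait FailingQuery {
  def stopInNewThread(error: Throwable): Unit
}

object CaptorSketch {
  def main(args: Array[String]): Unit = {
    val query = mock(classOf[FailingQuery])

    // Code under test would call this when a backlog queue overflows.
    query.stopInNewThread(
      new IllegalStateException("Size of the epoch queue has exceeded its maximum"))

    // Capture every Throwable passed to stopInNewThread and assert one matches.
    val captor = ArgumentCaptor.forClass(classOf[Throwable])
    verify(query, atLeastOnce()).stopInNewThread(captor.capture())

    import scala.collection.JavaConverters._
    val matched = captor.getAllValues.asScala
      .exists(_.getMessage == "Size of the epoch queue has exceeded its maximum")
    assert(matched, "expected failure message was not captured")
  }
}
```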
