Skip to content

Commit cbe6ead

Browse files
committed
[SPARK-29352][SQL][SS] Track active streaming queries in the SparkSession.sharedState
### What changes were proposed in this pull request? This moves the tracking of active queries from a per SparkSession state, to the shared SparkSession for better safety in isolated Spark Session environments. ### Why are the changes needed? We have checks to prevent the restarting of the same stream on the same spark session, but we can actually make that better in multi-tenant environments by actually putting that state in the SharedState instead of SessionState. This would allow a more comprehensive check for multi-tenant clusters. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added tests to StreamingQueryManagerSuite Closes #26018 from brkyvz/sharedStreamingQueryManager. Lead-authored-by: Burak Yavuz <burak@databricks.com> Co-authored-by: Burak Yavuz <brkyvz@gmail.com> Signed-off-by: Burak Yavuz <brkyvz@gmail.com>
1 parent 8c34690 commit cbe6ead

File tree

3 files changed

+102
-10
lines changed

3 files changed

+102
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
package org.apache.spark.sql.internal
1919

2020
import java.net.URL
21-
import java.util.Locale
21+
import java.util.{Locale, UUID}
22+
import java.util.concurrent.ConcurrentHashMap
2223

2324
import scala.reflect.ClassTag
2425
import scala.util.control.NonFatal
@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._
3334
import org.apache.spark.sql.execution.CacheManager
3435
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
3536
import org.apache.spark.sql.internal.StaticSQLConf._
37+
import org.apache.spark.sql.streaming.StreamingQueryManager
3638
import org.apache.spark.status.ElementTrackingStore
3739
import org.apache.spark.util.Utils
3840

@@ -110,6 +112,12 @@ private[sql] class SharedState(
110112
*/
111113
val cacheManager: CacheManager = new CacheManager
112114

115+
/**
116+
* A map of active streaming queries to the session specific StreamingQueryManager that manages
117+
* the lifecycle of that stream.
118+
*/
119+
private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamingQueryManager]()
120+
113121
/**
114122
* A status store to query SQL status/metrics of this Spark application, based on SQL-specific
115123
* [[org.apache.spark.scheduler.SparkListenerEvent]]s.

sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,10 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
352352
}
353353
}
354354

355-
// Make sure no other query with same id is active
356-
if (activeQueries.values.exists(_.id == query.id)) {
355+
// Make sure no other query with same id is active across all sessions
356+
val activeOption =
357+
Option(sparkSession.sharedState.activeStreamingQueries.putIfAbsent(query.id, this))
358+
if (activeOption.isDefined || activeQueries.values.exists(_.id == query.id)) {
357359
throw new IllegalStateException(
358360
s"Cannot start query with id ${query.id} as another query with same id is " +
359361
s"already active. Perhaps you are attempting to restart a query from checkpoint " +
@@ -370,19 +372,15 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
370372
query.streamingQuery.start()
371373
} catch {
372374
case e: Throwable =>
373-
activeQueriesLock.synchronized {
374-
activeQueries -= query.id
375-
}
375+
unregisterTerminatedStream(query.id)
376376
throw e
377377
}
378378
query
379379
}
380380

381381
/** Notify (by the StreamingQuery) that the query has been terminated */
382382
private[sql] def notifyQueryTermination(terminatedQuery: StreamingQuery): Unit = {
383-
activeQueriesLock.synchronized {
384-
activeQueries -= terminatedQuery.id
385-
}
383+
unregisterTerminatedStream(terminatedQuery.id)
386384
awaitTerminationLock.synchronized {
387385
if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) {
388386
lastTerminatedQuery = terminatedQuery
@@ -391,4 +389,12 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
391389
}
392390
stateStoreCoordinator.deactivateInstances(terminatedQuery.runId)
393391
}
392+
393+
private def unregisterTerminatedStream(terminatedQueryId: UUID): Unit = {
394+
activeQueriesLock.synchronized {
395+
// remove from shared state only if the streaming query manager also matches
396+
sparkSession.sharedState.activeStreamingQueries.remove(terminatedQueryId, this)
397+
activeQueries -= terminatedQueryId
398+
}
399+
}
394400
}

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.streaming
1919

20+
import java.io.File
2021
import java.util.concurrent.CountDownLatch
2122

2223
import scala.concurrent.Future
@@ -28,7 +29,7 @@ import org.scalatest.time.Span
2829
import org.scalatest.time.SpanSugar._
2930

3031
import org.apache.spark.SparkException
31-
import org.apache.spark.sql.Dataset
32+
import org.apache.spark.sql.{Dataset, Encoders}
3233
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
3334
import org.apache.spark.sql.execution.streaming._
3435
import org.apache.spark.sql.streaming.util.BlockingSource
@@ -242,6 +243,83 @@ class StreamingQueryManagerSuite extends StreamTest {
242243
}
243244
}
244245

246+
testQuietly("can't start a streaming query with the same name in the same session") {
247+
val ds1 = makeDataset._2
248+
val ds2 = makeDataset._2
249+
val queryName = "abc"
250+
251+
val query1 = ds1.writeStream.format("noop").queryName(queryName).start()
252+
try {
253+
val e = intercept[IllegalArgumentException] {
254+
ds2.writeStream.format("noop").queryName(queryName).start()
255+
}
256+
assert(e.getMessage.contains("query with that name is already active"))
257+
} finally {
258+
query1.stop()
259+
}
260+
}
261+
262+
testQuietly("can start a streaming query with the same name in a different session") {
263+
val session2 = spark.cloneSession()
264+
265+
val ds1 = MemoryStream(Encoders.INT, spark.sqlContext).toDS()
266+
val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS()
267+
val queryName = "abc"
268+
269+
val query1 = ds1.writeStream.format("noop").queryName(queryName).start()
270+
val query2 = ds2.writeStream.format("noop").queryName(queryName).start()
271+
272+
query1.stop()
273+
query2.stop()
274+
}
275+
276+
testQuietly("can't start multiple instances of the same streaming query in the same session") {
277+
withTempDir { dir =>
278+
val (ms1, ds1) = makeDataset
279+
val (ms2, ds2) = makeDataset
280+
val chkLocation = new File(dir, "_checkpoint").getCanonicalPath
281+
val dataLocation = new File(dir, "data").getCanonicalPath
282+
283+
val query1 = ds1.writeStream.format("parquet")
284+
.option("checkpointLocation", chkLocation).start(dataLocation)
285+
ms1.addData(1, 2, 3)
286+
try {
287+
val e = intercept[IllegalStateException] {
288+
ds2.writeStream.format("parquet")
289+
.option("checkpointLocation", chkLocation).start(dataLocation)
290+
}
291+
assert(e.getMessage.contains("same id"))
292+
} finally {
293+
query1.stop()
294+
}
295+
}
296+
}
297+
298+
testQuietly(
299+
"can't start multiple instances of the same streaming query in the different sessions") {
300+
withTempDir { dir =>
301+
val session2 = spark.cloneSession()
302+
303+
val ms1 = MemoryStream(Encoders.INT, spark.sqlContext)
304+
val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS()
305+
val chkLocation = new File(dir, "_checkpoint").getCanonicalPath
306+
val dataLocation = new File(dir, "data").getCanonicalPath
307+
308+
val query1 = ms1.toDS().writeStream.format("parquet")
309+
.option("checkpointLocation", chkLocation).start(dataLocation)
310+
ms1.addData(1, 2, 3)
311+
try {
312+
val e = intercept[IllegalStateException] {
313+
ds2.writeStream.format("parquet")
314+
.option("checkpointLocation", chkLocation).start(dataLocation)
315+
}
316+
assert(e.getMessage.contains("same id"))
317+
} finally {
318+
query1.stop()
319+
}
320+
}
321+
}
322+
245323
/** Run a body of code by defining a query on each dataset */
246324
private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = {
247325
failAfter(streamingTimeout) {

0 commit comments

Comments
 (0)