From 5d6e9dd6b1212823dd3aa148935723151027f911 Mon Sep 17 00:00:00 2001
From: Changgyoo Park
Date: Thu, 20 Jun 2024 08:49:43 +0900
Subject: [PATCH] [SPARK-47986][CONNECT][FOLLOW-UP] Unable to create a new session when the default session is closed by the server

### What changes were proposed in this pull request?

This is a Scala port of https://github.com/apache/spark/pull/46221 and
https://github.com/apache/spark/pull/46435.

A client is unaware of a server restart, or of the server having closed the
client's session, until it receives an error. However, at that point the client
is unable to create a new session to the same Connect endpoint, because the
stale session is still recorded as the active and default session.

With this change, when the server communicates via a gRPC error that the
session has changed, the session and the respective client are marked as stale,
thereby allowing a new default connection to be created via the session builder.

In some cases, particularly when running against older versions of the Spark
cluster (3.5), the error actually manifests as a mismatch in the observed
server-side session ID between calls. With this fix, we capture and handle that
case as well.

### Why are the changes needed?

Being unable to use getOrCreate() after an error is unacceptable and should be
fixed.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

./build/sbt testOnly *SparkSessionE2ESuite

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47008 from changgyoopark-db/SPARK-47986.

Authored-by: Changgyoo Park
Signed-off-by: Hyukjin Kwon
---
 .../org/apache/spark/sql/SparkSession.scala   | 35 +++++++++++++----
 .../spark/sql/SparkSessionE2ESuite.scala      | 39 +++++++++++++++++++
 .../connect/client/ResponseValidator.scala    | 29 +++++++++++++-
 .../connect/client/SparkConnectClient.scala   | 11 ++++++
 4 files changed, 105 insertions(+), 9 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 19c5a3f14c64f..80336fb1eaea4 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -829,10 +829,16 @@
 
   /**
    * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
-   * they are not set yet.
+   * they are not set yet or the associated [[SparkConnectClient]] is unusable.
    */
   private def setDefaultAndActiveSession(session: SparkSession): Unit = {
-    defaultSession.compareAndSet(null, session)
+    val currentDefault = defaultSession.getAcquire
+    if (currentDefault == null || !currentDefault.client.isSessionValid) {
+      // Update `defaultSession` if it is null or the contained session is not valid. There is a
+      // chance that the following `compareAndSet` fails if a new default session has just been set,
+      // but that does not matter since that event has happened after this method was invoked.
+      defaultSession.compareAndSet(currentDefault, session)
+    }
     if (getActiveSession.isEmpty) {
       setActiveSession(session)
     }
@@ -972,7 +978,7 @@ def appName(name: String): Builder = this
 
     private def tryCreateSessionFromClient(): Option[SparkSession] = {
-      if (client != null) {
+      if (client != null && client.isSessionValid) {
         Option(new SparkSession(client, planIdGenerator))
       } else {
         None
       }
     }
@@ -1024,7 +1030,16 @@
      */
     def getOrCreate(): SparkSession = {
       val session = tryCreateSessionFromClient()
-        .getOrElse(sessions.get(builder.configuration))
+        .getOrElse({
+          var existingSession = sessions.get(builder.configuration)
+          if (!existingSession.client.isSessionValid) {
+            // If the cached session has become invalid, e.g., due to a server restart, the cache
+            // entry is invalidated.
+            sessions.invalidate(builder.configuration)
+            existingSession = sessions.get(builder.configuration)
+          }
+          existingSession
+        })
       setDefaultAndActiveSession(session)
       applyOptions(session)
       session
@@ -1032,11 +1047,13 @@
   }
 
   /**
-   * Returns the default SparkSession.
+   * Returns the default SparkSession. If the previously set default SparkSession becomes
+   * unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
+  def getDefaultSession: Option[SparkSession] =
+    Option(defaultSession.get()).filter(_.client.isSessionValid)
 
   /**
    * Sets the default SparkSession.
@@ -1057,11 +1074,13 @@
   }
 
   /**
-   * Returns the active SparkSession for the current thread.
+   * Returns the active SparkSession for the current thread. If the previously set active
+   * SparkSession becomes unusable, returns None.
    *
    * @since 3.5.0
    */
-  def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())
+  def getActiveSession: Option[SparkSession] =
+    Option(activeThreadSession.get()).filter(_.client.isSessionValid)
 
   /**
    * Changes the SparkSession that will be returned in this thread and its children when
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 203b1295005af..b28aa905c7a29 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
       .create()
     }
   }
+
+  test("SPARK-47986: get or create after session changed") {
+    val remote = s"sc://localhost:$serverPort"
+
+    SparkSession.clearDefaultSession()
+    SparkSession.clearActiveSession()
+
+    val session1 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 eq SparkSession.getActiveSession.get)
+    assert(session1 eq SparkSession.getDefaultSession.get)
+    assert(session1.range(3).collect().length == 3)
+
+    session1.client.hijackServerSideSessionIdForTesting("-testing")
+
+    val e = intercept[SparkException] {
+      session1.range(3).analyze
+    }
+
+    assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]"))
+    assert(!session1.client.isSessionValid)
+    assert(SparkSession.getActiveSession.isEmpty)
+    assert(SparkSession.getDefaultSession.isEmpty)
+
+    val session2 = SparkSession
+      .builder()
+      .remote(remote)
+      .getOrCreate()
+
+    assert(session1 ne session2)
+    assert(session2.client.isSessionValid)
+    assert(session2 eq SparkSession.getActiveSession.get)
+    assert(session2 eq SparkSession.getDefaultSession.get)
+    assert(session2.range(3).collect().length == 3)
+  }
 }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 29272c96132bc..42c3387335be9 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.connect.client
 
+import java.util.concurrent.atomic.AtomicBoolean
+
 import com.google.protobuf.GeneratedMessageV3
+import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
@@ -30,6 +33,12 @@ class ResponseValidator extends Logging {
   // do not use server-side streaming.
   private var serverSideSessionId: Option[String] = None
 
+  // Indicates whether the client and the client information on the server correspond to each other.
+  // This flag being false means that the server has restarted and lost the client information, or
+  // that there is a logic error in the code; in both cases, the user should establish a new
+  // connection to the server. Access to the value has to be synchronized since it can be shared.
+  private val isSessionActive: AtomicBoolean = new AtomicBoolean(true)
+
   // Returns the server side session ID, used to send it back to the server in the follow-up
   // requests so the server can validate it session id against the previous requests.
   def getServerSideSessionId: Option[String] = serverSideSessionId
 
@@ -42,8 +51,25 @@ class ResponseValidator extends Logging {
     serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
   }
 
+  /**
+   * Returns true if the session is valid on both the client and the server.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // An active session is considered valid.
+    isSessionActive.getAcquire
+  }
+
   def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
-    val response = fn
+    val response =
+      try {
+        fn
+      } catch {
+        case e: StatusRuntimeException
+            if e.getStatus.getCode == Status.Code.INTERNAL &&
+              e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") =>
+          isSessionActive.setRelease(false)
+          throw e
+      }
     val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
     // If the field does not exist, we ignore it. New / Old message might not contain it and this
     // behavior allows us to be compatible.
@@ -54,6 +80,7 @@ class ResponseValidator extends Logging {
       serverSideSessionId match {
         case Some(id) =>
           if (value != id) {
+            isSessionActive.setRelease(false)
             throw new IllegalStateException(
               s"Server side session ID changed from $id to $value")
           }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index b5eda024bfb3c..7c3108fdb1b0e 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -71,6 +71,17 @@ private[sql] class SparkConnectClient(
     stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
   }
 
+  /**
+   * Returns true if the session is valid on both the client and the server. A session becomes
+   * invalid if the server side information about the client, e.g., session ID, does not
+   * correspond to the actual client state.
+   */
+  private[sql] def isSessionValid: Boolean = {
+    // The last known state of the session is stored in `responseValidator`, because it is where
+    // the client gets responses from the server.
+    stubState.responseValidator.isSessionValid
+  }
+
   private[sql] val artifactManager: ArtifactManager = {
     new ArtifactManager(configuration, sessionId, bstub, stub)
   }
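For context, here is a condensed sketch of the user-facing recovery flow that the new `SparkSessionE2ESuite` test exercises. It is illustrative only: the endpoint `sc://localhost:15002` and the catch-all `SparkException` handling are placeholders, not APIs introduced by this change.

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object SessionRecoverySketch {
  def main(args: Array[String]): Unit = {
    // Placeholder Spark Connect endpoint; substitute your own server address.
    val remote = "sc://localhost:15002"

    val first = SparkSession.builder().remote(remote).getOrCreate()
    first.range(3).collect() // succeeds while the server-side session is alive

    // If the server restarts or closes the session, the next RPC fails. The client
    // now marks the stale session as invalid instead of keeping it as the default.
    val active =
      try {
        first.range(3).collect()
        first
      } catch {
        case _: SparkException =>
          // With this change, getOrCreate() no longer hands back the stale default
          // session; it builds a fresh session against the same endpoint.
          SparkSession.builder().remote(remote).getOrCreate()
      }

    active.range(3).collect()
  }
}
```

The same recovery path applies when an older (3.5) server reports a server-side session ID mismatch instead of the `INVALID_HANDLE.SESSION_CHANGED` error.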